Skip to content

Commit

Permalink
Add limit and offset parse error
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneli committed Nov 27, 2019
1 parent 5940839 commit 23f6840
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Dialect interface {
ModifyColumn(tableName string, columnName string, typ string) error

// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
Expand Down
19 changes: 16 additions & 3 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,23 @@ func (s commonDialect) CurrentDatabase() (name string) {
return
}

func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
// LimitAndOffsetSQL return generated SQL with Limit and Offset
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
Expand Down Expand Up @@ -181,6 +190,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
return indexName, columnName
}

func (commonDialect) parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}

// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
Expand Down
15 changes: 11 additions & 4 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
Expand Down Expand Up @@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
return err
}

func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)

if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
Expand Down
17 changes: 14 additions & 3 deletions dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,25 @@ func (s mssql) CurrentDatabase() (name string) {
return
}

func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
parseInt := func(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
Expand Down
68 changes: 68 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,74 @@ func TestOffset(t *testing.T) {
}
}

func TestLimitAndOffsetSQL(t *testing.T) {
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
t.Fatal(err)
}

tests := []struct {
name string
limit, offset interface{}
users []*User
ok bool
}{
{
name: "OK",
limit: float64(2),
offset: float64(2),
users: []*User{
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
},
ok: true,
},
{
name: "Limit parse error",
limit: float64(1000000), // 1e+06
offset: float64(2),
ok: false,
},
{
name: "Offset parse error",
limit: float64(2),
offset: float64(1000000), // 1e+06
ok: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var users []*User
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
if tt.ok {
if err != nil {
t.Errorf("error expected nil, but got %v", err)
}
if len(users) != len(tt.users) {
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
}
for i := range tt.users {
if users[i].Name != tt.users[i].Name {
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
}
if users[i].Age != tt.users[i].Age {
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
}
}
} else {
if err == nil {
t.Error("error expected not nil, but got nil")
}
}
})
}
}

func TestOr(t *testing.T) {
user1 := User{Name: "OrUser1", Age: 1}
user2 := User{Name: "OrUser2", Age: 10}
Expand Down
4 changes: 3 additions & 1 deletion scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string {
}

func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
scope.Err(err)
return sql
}

func (scope *Scope) groupSQL() string {
Expand Down

0 comments on commit 23f6840

Please sign in to comment.