Skip to content

Commit

Permalink
expression: fix wrong result when select with collation (#18665) (#18735
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ti-srebot authored Jul 27, 2020
1 parent 0529b1b commit 75171f0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 9 deletions.
24 changes: 22 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (b *baseBuiltinFunc) collator() collate.Collator {

func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expression) (baseBuiltinFunc, error) {
if ctx == nil {
panic("ctx should not be nil")
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
if err := checkIllegalMixCollation(funcName, args); err != nil {
return baseBuiltinFunc{}, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
panic("unexpected length of args and argTps")
}
if ctx == nil {
panic("ctx should not be nil")
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}

for i := range args {
Expand Down Expand Up @@ -244,6 +244,26 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
return bf, nil
}

// newBaseBuiltinFuncWithFieldType create BaseBuiltinFunc with FieldType charset and collation.
// do not check and compute collation.
func newBaseBuiltinFuncWithFieldType(ctx sessionctx.Context, tp *types.FieldType, args []Expression) (baseBuiltinFunc, error) {
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
bf := baseBuiltinFunc{
bufAllocator: newLocalSliceBuffer(len(args)),
childrenVectorizedOnce: new(sync.Once),
childrenReversedOnce: new(sync.Once),

args: args,
ctx: ctx,
tp: types.NewFieldType(mysql.TypeUnspecified),
}
bf.SetCharsetAndCollation(tp.Charset, tp.Collate)
bf.setCollator(collate.GetCollator(tp.Collate))
return bf, nil
}

func (b *baseBuiltinFunc) getArgs() []Expression {
return b.args
}
Expand Down
8 changes: 1 addition & 7 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tipb/go-tipb"
)
Expand All @@ -44,7 +43,7 @@ func PbTypeToFieldType(tp *tipb.FieldType) *types.FieldType {

func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (f builtinFunc, e error) {
fieldTp := PbTypeToFieldType(tp)
base, err := newBaseBuiltinFunc(ctx, fmt.Sprintf("PBSig-%v", sigCode), args)
base, err := newBaseBuiltinFuncWithFieldType(ctx, fieldTp, args)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1132,11 +1131,6 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *stmtctx.StatementCont
return nil, err
}

// recover collation information
if collate.NewCollationEnabled() {
tp := sf.GetType()
sf.SetCharsetAndCollation(tp.Charset, tp.Collate)
}
return sf, nil
}

Expand Down
15 changes: 15 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6719,3 +6719,18 @@ func (s *testIntegrationSerialSuite) TestIssue18702(c *C) {
tk.MustExec("ROLLBACK;")
tk.MustQuery("SELECT * FROM t FORCE INDEX(idx_bc);").Check(testkit.Rows("1 A 10 1", "2 B 20 1"))
}

func (s *testIntegrationSerialSuite) TestIssue18662(c *C) {
collate.SetNewCollationEnabledForTest(true)
defer collate.SetNewCollationEnabledForTest(false)

tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a varchar(10) collate utf8mb4_bin, b varchar(10) collate utf8mb4_general_ci);")
tk.MustExec("insert into t (a, b) values ('a', 'A');")
tk.MustQuery("select * from t where field('A', a collate utf8mb4_general_ci, b) > 1;").Check(testkit.Rows())
tk.MustQuery("select * from t where field('A', a, b collate utf8mb4_general_ci) > 1;").Check(testkit.Rows())
tk.MustQuery("select * from t where field('A' collate utf8mb4_general_ci, a, b) > 1;").Check(testkit.Rows())
tk.MustQuery("select * from t where field('A', a, b) > 1;").Check(testkit.Rows("a A"))
}

0 comments on commit 75171f0

Please sign in to comment.