diff --git a/expression/builtin.go b/expression/builtin.go index 0615e7b8ba161..29ee2009f011f 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -88,7 +88,7 @@ func (b *baseBuiltinFunc) collator() collate.Collator { func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expression) (baseBuiltinFunc, error) { if ctx == nil { - panic("ctx should not be nil") + return baseBuiltinFunc{}, errors.New("unexpected nil session ctx") } if err := checkIllegalMixCollation(funcName, args); err != nil { return baseBuiltinFunc{}, err @@ -134,7 +134,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex panic("unexpected length of args and argTps") } if ctx == nil { - panic("ctx should not be nil") + return baseBuiltinFunc{}, errors.New("unexpected nil session ctx") } for i := range args { @@ -244,6 +244,26 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex return bf, nil } +// newBaseBuiltinFuncWithFieldType create BaseBuiltinFunc with FieldType charset and collation. +// do not check and compute collation. +func newBaseBuiltinFuncWithFieldType(ctx sessionctx.Context, tp *types.FieldType, args []Expression) (baseBuiltinFunc, error) { + if ctx == nil { + return baseBuiltinFunc{}, errors.New("unexpected nil session ctx") + } + bf := baseBuiltinFunc{ + bufAllocator: newLocalSliceBuffer(len(args)), + childrenVectorizedOnce: new(sync.Once), + childrenReversedOnce: new(sync.Once), + + args: args, + ctx: ctx, + tp: types.NewFieldType(mysql.TypeUnspecified), + } + bf.SetCharsetAndCollation(tp.Charset, tp.Collate) + bf.setCollator(collate.GetCollator(tp.Collate)) + return bf, nil +} + func (b *baseBuiltinFunc) getArgs() []Expression { return b.args } diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 5de1c1248441b..63e31914854be 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" - "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tipb/go-tipb" ) @@ -44,7 +43,7 @@ func PbTypeToFieldType(tp *tipb.FieldType) *types.FieldType { func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *tipb.FieldType, args []Expression) (f builtinFunc, e error) { fieldTp := PbTypeToFieldType(tp) - base, err := newBaseBuiltinFunc(ctx, fmt.Sprintf("PBSig-%v", sigCode), args) + base, err := newBaseBuiltinFuncWithFieldType(ctx, fieldTp, args) if err != nil { return nil, err } @@ -1132,11 +1131,6 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *stmtctx.StatementCont return nil, err } - // recover collation information - if collate.NewCollationEnabled() { - tp := sf.GetType() - sf.SetCharsetAndCollation(tp.Charset, tp.Collate) - } return sf, nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index 83e5781fcff63..45516dd621534 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6719,3 +6719,18 @@ func (s *testIntegrationSerialSuite) TestIssue18702(c *C) { tk.MustExec("ROLLBACK;") tk.MustQuery("SELECT * FROM t FORCE INDEX(idx_bc);").Check(testkit.Rows("1 A 10 1", "2 B 20 1")) } + +func (s *testIntegrationSerialSuite) TestIssue18662(c *C) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a varchar(10) collate utf8mb4_bin, b varchar(10) collate utf8mb4_general_ci);") + tk.MustExec("insert into t (a, b) values ('a', 'A');") + tk.MustQuery("select * from t where field('A', a collate utf8mb4_general_ci, b) > 1;").Check(testkit.Rows()) + tk.MustQuery("select * from t where field('A', a, b collate utf8mb4_general_ci) > 1;").Check(testkit.Rows()) + tk.MustQuery("select * from t where field('A' collate utf8mb4_general_ci, a, b) > 1;").Check(testkit.Rows()) + tk.MustQuery("select * from t where field('A', a, b) > 1;").Check(testkit.Rows("a A")) +}