diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 21d02ad39c0ce..2fc71ffa4e4c6 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/pkg/ddl/resourcegroup" ddlutil "github.com/pingcap/tidb/pkg/ddl/util" rg "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/kv" @@ -2267,7 +2268,8 @@ func BuildTableInfo( if len(hiddenCols) > 0 { AddIndexColumnFlag(tbInfo, idxInfo) } - _, err = validateCommentLength(ctx.GetSessionVars(), idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) + sessionVars := ctx.GetSessionVars() + _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) if err != nil { return nil, errors.Trace(err) } @@ -2621,7 +2623,8 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh return nil, errors.Trace(err) } - if _, err = validateCommentLength(ctx.GetSessionVars(), tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { return nil, errors.Trace(err) } @@ -5562,7 +5565,9 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col if col.Comment, err = value.ToString(); err != nil { return errors.Trace(err) } - col.Comment, err = validateCommentLength(ctx.GetSessionVars(), col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) + + sessionVars := ctx.GetSessionVars() + col.Comment, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) return errors.Trace(err) } @@ -6370,7 +6375,8 @@ func (d *ddl) AlterTableComment(ctx sessionctx.Context, ident ast.Ident, spec *a if err != nil { return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) } - if _, err = validateCommentLength(ctx.GetSessionVars(), ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil { return errors.Trace(err) } @@ -7413,7 +7419,8 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m // May be truncate comment here, when index comment too long and sql_mode is't strict. if indexOption != nil { - if _, err = validateCommentLength(ctx.GetSessionVars(), indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { return errors.Trace(err) } } @@ -7665,7 +7672,8 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde } // May be truncate comment here, when index comment too long and sql_mode is't strict. if indexOption != nil { - if _, err = validateCommentLength(ctx.GetSessionVars(), indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { return errors.Trace(err) } } @@ -8069,7 +8077,7 @@ func isDroppableColumn(tblInfo *model.TableInfo, colName model.CIStr) error { // validateCommentLength checks comment length of table, column, or index // If comment length is more than the standard length truncate it // and store the comment length upto the standard comment length size. -func validateCommentLength(vars *variable.SessionVars, name string, comment *string, errTooLongComment *terror.Error) (string, error) { +func validateCommentLength(ec errctx.Context, sqlMode mysql.SQLMode, name string, comment *string, errTooLongComment *terror.Error) (string, error) { if comment == nil { return "", nil } @@ -8086,11 +8094,11 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str } if len(*comment) > maxLen { err := errTooLongComment.GenWithStackByArgs(name, maxLen) - if vars.SQLMode.HasStrictMode() { + if sqlMode.HasStrictMode() { // may be treated like an error. return "", err } - vars.StmtCtx.AppendWarning(err) + ec.AppendWarning(err) *comment = (*comment)[:maxLen] } return *comment, nil @@ -8231,7 +8239,8 @@ func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } } - newVal, err := val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &colType) + evalCtx := ctx.GetEvalCtx() + newVal, err := val.ConvertTo(evalCtx.TypeCtx(), &colType) if err != nil { return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 32fe939bd36b8..622ddce7347db 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -803,12 +803,13 @@ func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.Partitio return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: number of partitions generated != partition defined (%d != %d)", len(a), len(m)) } + evalCtx := ctx.GetEvalCtx() evalFn := func(expr ast.ExprNode) (types.Datum, error) { val, err := expression.EvalSimpleAst(ctx, ast.NewValueExpr(expr, "", "")) if err != nil || partCol == nil { return val, err } - return val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType) + return val.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) } for i := range pAst.Definitions { // Allow options to differ! (like Placement Rules) @@ -840,7 +841,7 @@ func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.Partitio if err != nil { return err } - cmp, err := lessThanVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator()) + cmp, err := lessThanVal.Compare(evalCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator()) if err != nil { return err } @@ -1093,8 +1094,9 @@ func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTable if err != nil { return err } + evalCtx := ctx.GetEvalCtx() if partCol != nil { - lastVal, err = lastVal.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType) + lastVal, err = lastVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) if err != nil { return err } @@ -1143,12 +1145,12 @@ func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTable return err } if partCol != nil { - currVal, err = currVal.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType) + currVal, err = currVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) if err != nil { return err } } - cmp, err := currVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator()) + cmp, err := currVal.Compare(evalCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator()) if err != nil { return err } @@ -1416,7 +1418,8 @@ func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.Par } } comment, _ := def.Comment() - comment, err := validateCommentLength(ctx.GetSessionVars(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment) + evalCtx := ctx.GetEvalCtx() + comment, err := validateCommentLength(evalCtx.ErrCtx(), evalCtx.SQLMode(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment) if err != nil { return nil, err } @@ -1499,7 +1502,8 @@ func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs [ return dbterror.ErrValuesIsNotIntType.GenWithStackByArgs(defName) } - _, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), tp) + evalCtx := ctx.GetEvalCtx() + _, err = val.ConvertTo(evalCtx.TypeCtx(), tp) if err != nil && !types.ErrOverflow.Equal(err) { return dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } diff --git a/pkg/executor/aggfuncs/BUILD.bazel b/pkg/executor/aggfuncs/BUILD.bazel index 7f61a483a8413..c3c7a70767b84 100644 --- a/pkg/executor/aggfuncs/BUILD.bazel +++ b/pkg/executor/aggfuncs/BUILD.bazel @@ -40,7 +40,6 @@ go_library( "//pkg/parser/charset", "//pkg/parser/mysql", "//pkg/planner/util", - "//pkg/sessionctx/variable", "//pkg/types", "//pkg/util/chunk", "//pkg/util/codec", diff --git a/pkg/executor/aggfuncs/builder.go b/pkg/executor/aggfuncs/builder.go index cb2b803f6fb74..13214c6c9a572 100644 --- a/pkg/executor/aggfuncs/builder.go +++ b/pkg/executor/aggfuncs/builder.go @@ -15,16 +15,13 @@ package aggfuncs import ( - "context" "fmt" - "strconv" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/expression/aggregation" exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -33,7 +30,7 @@ import ( ) // AggFuncBuildContext is used to build aggregation functions. -type AggFuncBuildContext = exprctx.BuildContext +type AggFuncBuildContext = exprctx.AggFuncBuildContext // Build is used to build a specific AggFunc implementation according to the // input aggFuncDesc. @@ -287,7 +284,7 @@ func buildSum(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, ord if aggFuncDesc.HasDistinct { return &sum4DistinctFloat64{base} } - if ctx.GetSessionVars().WindowingUseHighPrecision { + if ctx.GetWindowingUseHighPrecision() { return &sum4Float64HighPrecision{baseSum4Float64{base}} } return &sum4Float64{baseSum4Float64{base}} @@ -321,7 +318,7 @@ func buildAvg(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, ord if aggFuncDesc.HasDistinct { return &avgOriginal4DistinctFloat64{base} } - if ctx.GetSessionVars().WindowingUseHighPrecision { + if ctx.GetWindowingUseHighPrecision() { return &avgOriginal4Float64HighPrecision{baseAvgFloat64{base}} } return &avgOriginal4Float64{avgOriginal4Float64HighPrecision{baseAvgFloat64{base}}} @@ -477,16 +474,7 @@ func buildGroupConcat(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncD if err != nil { panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error())) } - var s string - s, err = ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.GroupConcatMaxLen) - if err != nil { - panic(fmt.Sprintf("Error happened when buildGroupConcat: no system variable named '%s'", variable.GroupConcatMaxLen)) - } - maxLen, err := strconv.ParseUint(s, 10, 64) - // Should never happen - if err != nil { - panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error())) - } + maxLen := ctx.GetGroupConcatMaxLen() var truncated int32 base := baseGroupConcat4String{ baseAggFunc: baseAggFunc{ @@ -719,7 +707,8 @@ func buildLeadLag(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, if len(aggFuncDesc.Args) == 3 { defaultExpr = aggFuncDesc.Args[2] if et, ok := defaultExpr.(*expression.Constant); ok { - res, err1 := et.Value.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), aggFuncDesc.RetTp) + evalCtx := ctx.GetEvalCtx() + res, err1 := et.Value.ConvertTo(evalCtx.TypeCtx(), aggFuncDesc.RetTp) if err1 == nil { defaultExpr = &expression.Constant{Value: res, RetType: aggFuncDesc.RetTp} } diff --git a/pkg/expression/aggregation/BUILD.bazel b/pkg/expression/aggregation/BUILD.bazel index 5e1b2e200c2d8..502d43d527cbc 100644 --- a/pkg/expression/aggregation/BUILD.bazel +++ b/pkg/expression/aggregation/BUILD.bazel @@ -42,7 +42,6 @@ go_library( "//pkg/parser/terror", "//pkg/planner/util", "//pkg/sessionctx/stmtctx", - "//pkg/sessionctx/variable", "//pkg/types", "//pkg/util/chunk", "//pkg/util/codec", diff --git a/pkg/expression/aggregation/base_func.go b/pkg/expression/aggregation/base_func.go index e6a73811070e1..43876b8888d0c 100644 --- a/pkg/expression/aggregation/base_func.go +++ b/pkg/expression/aggregation/base_func.go @@ -96,7 +96,7 @@ func (a *baseFuncDesc) TypeInfer(ctx expression.BuildContext) error { case ast.AggFuncSum: a.typeInfer4Sum() case ast.AggFuncAvg: - a.typeInfer4Avg(ctx.GetSessionVars().GetDivPrecisionIncrement()) + a.typeInfer4Avg(ctx.GetEvalCtx().GetDivPrecisionIncrement()) case ast.AggFuncGroupConcat: a.typeInfer4GroupConcat(ctx) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow, diff --git a/pkg/expression/aggregation/descriptor.go b/pkg/expression/aggregation/descriptor.go index cbb09e4351747..f83dd6c6bd1ae 100644 --- a/pkg/expression/aggregation/descriptor.go +++ b/pkg/expression/aggregation/descriptor.go @@ -16,17 +16,13 @@ package aggregation import ( "bytes" - "context" - "fmt" "math" - "strconv" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/size" @@ -218,7 +214,7 @@ func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, sche } // GetAggFunc gets an evaluator according to the aggregation function signature. -func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation { +func (a *AggFuncDesc) GetAggFunc(ctx expression.AggFuncBuildContext) Aggregation { aggFunc := aggFunction{AggFuncDesc: a} switch a.Name { case ast.AggFuncSum: @@ -228,18 +224,7 @@ func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation { case ast.AggFuncAvg: return &avgFunction{aggFunction: aggFunc} case ast.AggFuncGroupConcat: - var s string - var err error - var maxLen uint64 - s, err = ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.GroupConcatMaxLen) - if err != nil { - panic(fmt.Sprintf("Error happened when GetAggFunc: no system variable named '%s'", variable.GroupConcatMaxLen)) - } - maxLen, err = strconv.ParseUint(s, 10, 64) - if err != nil { - panic(fmt.Sprintf("Error happened when GetAggFunc: illegal value for system variable named '%s'", variable.GroupConcatMaxLen)) - } - return &concatFunction{aggFunction: aggFunc, maxLen: maxLen} + return &concatFunction{aggFunction: aggFunc, maxLen: ctx.GetGroupConcatMaxLen()} case ast.AggFuncMax: return &maxMinFunction{aggFunction: aggFunc, isMax: true, ctor: collate.GetCollator(a.Args[0].GetType().GetCollate())} case ast.AggFuncMin: diff --git a/pkg/expression/bench_test.go b/pkg/expression/bench_test.go index e43afafa5549c..1144a68f99a8b 100644 --- a/pkg/expression/bench_test.go +++ b/pkg/expression/bench_test.go @@ -44,7 +44,7 @@ import ( ) type benchHelper struct { - ctx BuildContext + ctx *mock.Context exprs []Expression inputTypes []*types.FieldType diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index be9413f29ff64..71415c1638ffc 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -340,7 +340,7 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr if err != nil { return nil, err } - if (mysql.HasUnsignedFlag(args[0].GetType().GetFlag()) || mysql.HasUnsignedFlag(args[1].GetType().GetFlag())) && !ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode() { + if (mysql.HasUnsignedFlag(args[0].GetType().GetFlag()) || mysql.HasUnsignedFlag(args[1].GetType().GetFlag())) && !ctx.GetEvalCtx().SQLMode().HasNoUnsignedSubtractionMode() { bf.tp.AddFlag(mysql.UnsignedFlag) } sig := &builtinArithmeticMinusIntSig{baseBuiltinFunc: bf} @@ -661,7 +661,7 @@ func (c *arithmeticDivideFunctionClass) getFunction(ctx BuildContext, args []Exp if err != nil { return nil, err } - c.setType4DivDecimal(bf.tp, lhsTp, rhsTp, ctx.GetSessionVars().GetDivPrecisionIncrement()) + c.setType4DivDecimal(bf.tp, lhsTp, rhsTp, ctx.GetEvalCtx().GetDivPrecisionIncrement()) sig := &builtinArithmeticDivideDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_DivideDecimal) return sig, nil diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 889735e9af371..b623a525cfc5c 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -2270,7 +2270,7 @@ func WrapWithCastAsString(ctx BuildContext, expr Expression) Expression { tp.SetCharset(charset.CharsetBin) tp.SetCollate(charset.CollationBin) } else { - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() tp.SetCharset(charset) tp.SetCollate(collate) } diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index dbcfaa08e69ea..f837971e3901f 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -430,7 +430,7 @@ func unsupportedJSONComparison(ctx BuildContext, args []Expression) { for _, arg := range args { tp := arg.GetType().GetType() if tp == mysql.TypeJSON { - ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedJSONComparison) + ctx.GetEvalCtx().AppendWarning(errUnsupportedJSONComparison) break } } @@ -1383,13 +1383,14 @@ func tryToConvertConstantInt(ctx BuildContext, targetFieldType *types.FieldType, if con.GetType().EvalType() == types.ETInt { return con, false } - dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) + + evalCtx := ctx.GetEvalCtx() + dt, err := con.Eval(evalCtx, chunk.Row{}) if err != nil { return con, false } - sc := ctx.GetSessionVars().StmtCtx - dt, err = dt.ConvertTo(sc.TypeCtx(), targetFieldType) + dt, err = dt.ConvertTo(evalCtx.TypeCtx(), targetFieldType) if err != nil { if terror.ErrorEqual(err, types.ErrOverflow) { return &Constant{ @@ -1418,17 +1419,16 @@ func tryToConvertConstantInt(ctx BuildContext, targetFieldType *types.FieldType, // If the op == LT,LE,GT,GE and it gets an Overflow when converting, return inf/-inf. // If the op == EQ,NullEQ and the constant can never be equal to the int column, return ‘con’(the input, a non-int constant). func RefineComparedConstant(ctx BuildContext, targetFieldType types.FieldType, con *Constant, op opcode.Op) (_ *Constant, isExceptional bool) { - dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) + evalCtx := ctx.GetEvalCtx() + dt, err := con.Eval(evalCtx, chunk.Row{}) if err != nil { return con, false } - sc := ctx.GetSessionVars().StmtCtx - if targetFieldType.GetType() == mysql.TypeBit { targetFieldType = *types.NewFieldType(mysql.TypeLonglong) } var intDatum types.Datum - intDatum, err = dt.ConvertTo(sc.TypeCtx(), &targetFieldType) + intDatum, err = dt.ConvertTo(evalCtx.TypeCtx(), &targetFieldType) if err != nil { if terror.ErrorEqual(err, types.ErrOverflow) { return &Constant{ @@ -1440,7 +1440,7 @@ func RefineComparedConstant(ctx BuildContext, targetFieldType types.FieldType, c } return con, false } - c, err := intDatum.Compare(sc.TypeCtx(), &con.Value, collate.GetBinaryCollator()) + c, err := intDatum.Compare(evalCtx.TypeCtx(), &con.Value, collate.GetBinaryCollator()) if err != nil { return con, false } @@ -1485,7 +1485,7 @@ func RefineComparedConstant(ctx BuildContext, targetFieldType types.FieldType, c // 3. Suppose the value of `con` is 2, when `targetFieldType.GetType()` is `TypeYear`, the value of `doubleDatum` // will be 2.0 and the value of `intDatum` will be 2002 in this case. var doubleDatum types.Datum - doubleDatum, err = dt.ConvertTo(sc.TypeCtx(), types.NewFieldType(mysql.TypeDouble)) + doubleDatum, err = dt.ConvertTo(evalCtx.TypeCtx(), types.NewFieldType(mysql.TypeDouble)) if err != nil { return con, false } @@ -1533,7 +1533,7 @@ func allowCmpArgsRefining4PlanCache(ctx BuildContext, args []Expression) (allowR exprEvalType := exprType.EvalType() if exprType.GetType() == mysql.TypeYear { reason := errors.NewNoStackErrorf("'%v' may be converted to INT", args[conIdx].String()) - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) + ctx.SetSkipPlanCache(reason) return true } @@ -1543,7 +1543,7 @@ func allowCmpArgsRefining4PlanCache(ctx BuildContext, args []Expression) (allowR if exprEvalType == types.ETInt && (conEvalType == types.ETString || conEvalType == types.ETReal || conEvalType == types.ETDecimal) { reason := errors.NewNoStackErrorf("'%v' may be converted to INT", args[conIdx].String()) - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) + ctx.SetSkipPlanCache(reason) return true } @@ -1553,7 +1553,7 @@ func allowCmpArgsRefining4PlanCache(ctx BuildContext, args []Expression) (allowR _, exprIsCon := args[1-conIdx].(*Constant) if !exprIsCon && matchRefineRule3Pattern(conEvalType, exprType) { reason := errors.Errorf("'%v' may be converted to datetime", args[conIdx].String()) - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) + ctx.SetSkipPlanCache(reason) return true } } @@ -1678,14 +1678,14 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( // see https://github.com/pingcap/tidb/issues/38361 for more details func (c *compareFunctionClass) refineNumericConstantCmpDatetime(ctx BuildContext, args []Expression, constArg *Constant, constArgIdx int) []Expression { - dt, err := constArg.Eval(ctx.GetEvalCtx(), chunk.Row{}) + evalCtx := ctx.GetEvalCtx() + dt, err := constArg.Eval(evalCtx, chunk.Row{}) if err != nil || dt.IsNull() { return args } - sc := ctx.GetSessionVars().StmtCtx var datetimeDatum types.Datum targetFieldType := types.NewFieldType(mysql.TypeDatetime) - datetimeDatum, err = dt.ConvertTo(sc.TypeCtx(), targetFieldType) + datetimeDatum, err = dt.ConvertTo(evalCtx.TypeCtx(), targetFieldType) if err != nil || datetimeDatum.IsNull() { return args } diff --git a/pkg/expression/builtin_encryption.go b/pkg/expression/builtin_encryption.go index abc2382b3e1f3..9666eee891ed3 100644 --- a/pkg/expression/builtin_encryption.go +++ b/pkg/expression/builtin_encryption.go @@ -123,7 +123,7 @@ func (c *aesDecryptFunctionClass) getFunction(ctx BuildContext, args []Expressio bf.tp.SetFlen(args[0].GetType().GetFlen()) // At most. types.SetBinChsClnFlag(bf.tp) - blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + blockMode := ctx.GetBlockEncryptionMode() mode, exists := aesModes[strings.ToLower(blockMode)] if !exists { return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) @@ -258,7 +258,7 @@ func (c *aesEncryptFunctionClass) getFunction(ctx BuildContext, args []Expressio bf.tp.SetFlen(aes.BlockSize * (args[0].GetType().GetFlen()/aes.BlockSize + 1)) // At most. types.SetBinChsClnFlag(bf.tp) - blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + blockMode := ctx.GetBlockEncryptionMode() mode, exists := aesModes[strings.ToLower(blockMode)] if !exists { return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) @@ -608,7 +608,7 @@ func (c *md5FunctionClass) getFunction(ctx BuildContext, args []Expression) (bui if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(32) @@ -651,7 +651,7 @@ func (c *sha1FunctionClass) getFunction(ctx BuildContext, args []Expression) (bu if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(40) @@ -698,7 +698,7 @@ func (c *sha2FunctionClass) getFunction(ctx BuildContext, args []Expression) (bu if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(128) // sha512 @@ -729,7 +729,7 @@ func (c *sm3FunctionClass) getFunction(ctx BuildContext, args []Expression) (bui if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(40) diff --git a/pkg/expression/builtin_info.go b/pkg/expression/builtin_info.go index e8383fbb33b0f..d772fd4767b25 100644 --- a/pkg/expression/builtin_info.go +++ b/pkg/expression/builtin_info.go @@ -758,7 +758,7 @@ func (c *charsetFunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) @@ -827,7 +827,7 @@ func (c *collationFunctionClass) getFunction(ctx BuildContext, args []Expression if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) @@ -1408,7 +1408,7 @@ func (c *formatBytesFunctionClass) getFunction(ctx BuildContext, args []Expressi if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) sig := &builtinFormatBytesSig{bf} @@ -1447,7 +1447,7 @@ func (c *formatNanoTimeFunctionClass) getFunction(ctx BuildContext, args []Expre if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) sig := &builtinFormatNanoTimeSig{bf} diff --git a/pkg/expression/builtin_json.go b/pkg/expression/builtin_json.go index 32f21d788998b..44c752dd31766 100644 --- a/pkg/expression/builtin_json.go +++ b/pkg/expression/builtin_json.go @@ -108,7 +108,7 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(51) // flen of JSON_TYPE is length of UNSIGNED INTEGER. diff --git a/pkg/expression/builtin_math.go b/pkg/expression/builtin_math.go index bf95e4294652d..a17307fbcfe16 100644 --- a/pkg/expression/builtin_math.go +++ b/pkg/expression/builtin_math.go @@ -1029,7 +1029,7 @@ func (c *randFunctionClass) getFunction(ctx BuildContext, args []Expression) (bu } bt := bf if len(args) == 0 { - sig = &builtinRandSig{bt, ctx.GetSessionVars().Rng} + sig = &builtinRandSig{bt, ctx.Rng()} sig.setPbCode(tipb.ScalarFuncSig_Rand) } else if _, isConstant := args[0].(*Constant); isConstant { // According to MySQL manual: @@ -1163,7 +1163,7 @@ func (c *convFunctionClass) getFunction(ctx BuildContext, args []Expression) (bu if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) diff --git a/pkg/expression/builtin_miscellaneous.go b/pkg/expression/builtin_miscellaneous.go index 5bc2d1a70e407..64960909ff590 100644 --- a/pkg/expression/builtin_miscellaneous.go +++ b/pkg/expression/builtin_miscellaneous.go @@ -588,7 +588,7 @@ func (c *inetNtoaFunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(93) @@ -716,7 +716,7 @@ func (c *inet6NtoaFunctionClass) getFunction(ctx BuildContext, args []Expression if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(117) @@ -1331,7 +1331,7 @@ func (c *uuidFunctionClass) getFunction(ctx BuildContext, args []Expression) (bu if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(36) @@ -1499,7 +1499,7 @@ func (c *binToUUIDFunctionClass) getFunction(ctx BuildContext, args []Expression return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(32) diff --git a/pkg/expression/builtin_op.go b/pkg/expression/builtin_op.go index 43ff23412b5a9..74263fb0c430a 100644 --- a/pkg/expression/builtin_op.go +++ b/pkg/expression/builtin_op.go @@ -739,7 +739,7 @@ func (c *unaryNotFunctionClass) getFunction(ctx BuildContext, args []Expression) sig = &builtinUnaryNotIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_UnaryNotInt) case types.ETJson: - ctx.GetSessionVars().StmtCtx.AppendWarning(errJSONInBooleanContext) + ctx.GetEvalCtx().AppendWarning(errJSONInBooleanContext) sig = &builtinUnaryNotJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_UnaryNotJSON) default: diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index 744a2e3a8d8e6..596db2863c880 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -165,7 +165,7 @@ func (c *inFunctionClass) verifyArgs(ctx BuildContext, args []Expression) ([]Exp case columnType.GetType() == mysql.TypeBit && constant.Value.Kind() == types.KindInt64: if constant.Value.GetInt64() < 0 { if MaybeOverOptimized4PlanCache(ctx, args) { - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackErrorf("Bit Column in (%v)", constant.Value.GetInt64())) + ctx.SetSkipPlanCache(errors.NewNoStackErrorf("Bit Column in (%v)", constant.Value.GetInt64())) } continue } diff --git a/pkg/expression/builtin_string.go b/pkg/expression/builtin_string.go index e0771213915c6..d001d050d1785 100644 --- a/pkg/expression/builtin_string.go +++ b/pkg/expression/builtin_string.go @@ -816,7 +816,7 @@ func (c *spaceFunctionClass) getFunction(ctx BuildContext, args []Expression) (b if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(mysql.MaxBlobWidth) @@ -1622,7 +1622,7 @@ func (c *hexFunctionClass) getFunction(ctx BuildContext, args []Expression) (bui return nil, err } bf.tp.SetFlen(args[0].GetType().GetFlen() * 2) - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) sig := &builtinHexIntArgSig{bf} @@ -2738,7 +2738,7 @@ func (c *octFunctionClass) getFunction(ctx BuildContext, args []Expression) (bui if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) @@ -2751,7 +2751,7 @@ func (c *octFunctionClass) getFunction(ctx BuildContext, args []Expression) (bui if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) @@ -2966,7 +2966,7 @@ func (c *binFunctionClass) getFunction(ctx BuildContext, args []Expression) (bui if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) @@ -3248,7 +3248,7 @@ func (c *formatFunctionClass) getFunction(ctx BuildContext, args []Expression) ( if err != nil { return nil, err } - charset, colalte := ctx.GetSessionVars().GetCharsetInfo() + charset, colalte := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(colalte) bf.tp.SetFlen(mysql.MaxBlobWidth) @@ -3497,7 +3497,7 @@ func (c *toBase64FunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) @@ -3812,7 +3812,7 @@ func (c *loadFileFunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(64) diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index 6e3e4a053d939..62b7cfba02285 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -1080,7 +1080,7 @@ func (c *monthNameFunctionClass) getFunction(ctx BuildContext, args []Expression if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(10) @@ -1125,7 +1125,7 @@ func (c *dayNameFunctionClass) getFunction(ctx BuildContext, args []Expression) if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) bf.tp.SetFlen(10) @@ -4688,7 +4688,7 @@ func (c *addTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) sig.setPbCode(tipb.ScalarFuncSig_AddDatetimeAndString) } case mysql.TypeDate: - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) switch tp2.GetType() { @@ -5657,7 +5657,7 @@ func (c *subTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) sig.setPbCode(tipb.ScalarFuncSig_SubDatetimeAndString) } case mysql.TypeDate: - charset, collate := ctx.GetSessionVars().GetCharsetInfo() + charset, collate := ctx.GetCharsetInfo() bf.tp.SetCharset(charset) bf.tp.SetCollate(collate) switch tp2.GetType() { diff --git a/pkg/expression/collation.go b/pkg/expression/collation.go index 73f8408527319..b9a02e7223adb 100644 --- a/pkg/expression/collation.go +++ b/pkg/expression/collation.go @@ -240,7 +240,7 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy return CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...) } case ast.DateFormat, ast.TimeFormat: - charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo() + charsetInfo, collation := ctx.GetCharsetInfo() return &ExprCollation{args[1].Coercibility(), args[1].Repertoire(), charsetInfo, collation}, nil case ast.Cast: // We assume all the cast are implicit. @@ -248,7 +248,7 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy // Non-string type cast to string type should use @@character_set_connection and @@collation_connection. // String type cast to string type should keep its original charset and collation. It should not happen. if retType == types.ETString && argTps[0] != types.ETString { - ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() + ec.Charset, ec.Collation = ctx.GetCharsetInfo() } return ec, nil case ast.Case: @@ -281,7 +281,7 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2, ast.SM3: // should return ASCII repertoire, MySQL's doc says it depends on character_set_connection, but it not true from its source code. ec = &ExprCollation{Coer: CoercibilityCoercible, Repe: ASCII} - ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() + ec.Charset, ec.Collation = ctx.GetCharsetInfo() return ec, nil case ast.JSONPretty, ast.JSONQuote: // JSON function always return utf8mb4 and utf8mb4_bin. @@ -291,7 +291,7 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy ec = &ExprCollation{CoercibilityNumeric, ASCII, charset.CharsetBin, charset.CollationBin} if retType == types.ETString { - ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() + ec.Charset, ec.Collation = ctx.GetCharsetInfo() ec.Coer = CoercibilityCoercible if ec.Charset != charset.CharsetASCII { ec.Repe = UNICODE @@ -312,7 +312,7 @@ func CheckAndDeriveCollationFromExprs(ctx BuildContext, funcName string, evalTyp } if evalType == types.ETString && ec.Coer == CoercibilityNumeric { - ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() + ec.Charset, ec.Collation = ctx.GetCharsetInfo() ec.Coer = CoercibilityCoercible ec.Repe = ASCII } diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index ce51aaee189eb..55931403864ab 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -220,7 +220,8 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) { // of Constant to nil is ok. return &Constant{Value: value, RetType: x.RetType}, false } - if isTrue, err := value.ToBool(sc.TypeCtx()); err == nil && isTrue == 0 { + evalCtx := ctx.GetEvalCtx() + if isTrue, err := value.ToBool(evalCtx.TypeCtx()); err == nil && isTrue == 0 { // This Constant is created to compose the result expression of EvaluateExprWithNull when InNullRejectCheck // is true. We just check whether the result expression is null or false and then let it die. Basically, // the constant is used once briefly and will not be retained for a long time. Hence setting DeferredExpr diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index aa0985e0e12b5..e21cfbcc19075 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -55,13 +55,14 @@ func (s *basePropConstSolver) insertCol(col *Column) { // tryToUpdateEQList tries to update the eqList. When the eqList has store this column with a different constant, like // a = 1 and a = 2, we set the second return value to false. func (s *basePropConstSolver) tryToUpdateEQList(col *Column, con *Constant) (bool, bool) { - if con.Value.IsNull() && ConstExprConsiderPlanCache(con, s.ctx.GetSessionVars().StmtCtx.UseCache) { + if con.Value.IsNull() && ConstExprConsiderPlanCache(con, s.ctx.IsUseCache()) { return false, true } id := s.getColID(col) oldCon := s.eqList[id] if oldCon != nil { - res, err := oldCon.Value.Compare(s.ctx.GetSessionVars().StmtCtx.TypeCtx(), &con.Value, collate.GetCollator(col.GetType().GetCollate())) + evalCtx := s.ctx.GetEvalCtx() + res, err := oldCon.Value.Compare(evalCtx.TypeCtx(), &con.Value, collate.GetCollator(col.GetType().GetCollate())) return false, res != 0 || err != nil } s.eqList[id] = con @@ -281,7 +282,7 @@ func (s *propConstSolver) propagateColumnEQ() { func (s *propConstSolver) setConds2ConstFalse() { if MaybeOverOptimized4PlanCache(s.ctx, s.conditions) { - s.ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("some parameters may be overwritten when constant propagation")) + s.ctx.SetSkipPlanCache(errors.New("some parameters may be overwritten when constant propagation")) } s.conditions = []Expression{&Constant{ Value: types.NewDatum(false), @@ -397,7 +398,7 @@ func (s *basePropConstSolver) dealWithPossibleHybridType(col *Column, con *Const return nil, false } if MaybeOverOptimized4PlanCache(s.ctx, []Expression{con}) { - s.ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("Skip plan cache since mutable constant is restored and propagated")) + s.ctx.SetSkipPlanCache(errors.New("Skip plan cache since mutable constant is restored and propagated")) } switch d.Kind() { case types.KindInt64: diff --git a/pkg/expression/context.go b/pkg/expression/context.go index 24ba84f3f5967..1b1fe8cf67e46 100644 --- a/pkg/expression/context.go +++ b/pkg/expression/context.go @@ -32,6 +32,9 @@ type EvalContext = context.EvalContext // BuildContext is used to build an expression type BuildContext = context.BuildContext +// AggFuncBuildContext is used to build an aggregation expression +type AggFuncBuildContext = context.AggFuncBuildContext + // OptionalEvalPropKey is an alias of context.OptionalEvalPropKey type OptionalEvalPropKey = context.OptionalEvalPropKey diff --git a/pkg/expression/context/BUILD.bazel b/pkg/expression/context/BUILD.bazel index 8215e137de48f..5d1373913b8d5 100644 --- a/pkg/expression/context/BUILD.bazel +++ b/pkg/expression/context/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//pkg/sessionctx/variable", "//pkg/types", "//pkg/util/intest", + "//pkg/util/mathutil", ], ) diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 866499fab1433..c1d8fcd5802ec 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/mathutil" ) // EvalContext is used to evaluate an expression @@ -68,6 +69,23 @@ type EvalContext interface { type BuildContext interface { // GetEvalCtx returns the EvalContext. GetEvalCtx() EvalContext + // GetCharsetInfo gets charset and collation for current context. + GetCharsetInfo() (string, string) + // GetDefaultCollationForUTF8MB4 returns the default collation of UTF8MB4. + GetDefaultCollationForUTF8MB4() string + // GetBlockEncryptionMode returns the variable `block_encryption_mode`. + GetBlockEncryptionMode() string + // GetSysdateIsNow returns a bool to determine whether Sysdate is an alias of Now function. + // It is the value of variable `tidb_sysdate_is_now`. + GetSysdateIsNow() bool + // GetNoopFuncsMode returns the noop function mode: OFF/ON/WARN values as 0/1/2. + GetNoopFuncsMode() int + // Rng is used to generate random values. + Rng() *mathutil.MysqlRng + // IsUseCache indicates whether to cache the build expression in plan cache. + IsUseCache() bool + // SetSkipPlanCache sets to skip the plan cache and records the reason. + SetSkipPlanCache(reason error) // GetSessionVars gets the session variables. GetSessionVars() *variable.SessionVars // Value returns the value associated with this context for key. @@ -75,3 +93,19 @@ type BuildContext interface { // SetValue saves a value associated with this context for key. SetValue(key fmt.Stringer, value any) } + +// AggFuncBuildContext is used to build an aggregate function +type AggFuncBuildContext interface { + BuildContext + // GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision. + // see https://dev.mysql.com/doc/refman/8.0/en/window-function-optimization.html for more details. + GetWindowingUseHighPrecision() bool + // GetGroupConcatMaxLen returns the value of the 'group_concat_max_len' system variable. + GetGroupConcatMaxLen() uint64 +} + +// ExprContext contains full context for expression building and evaluating. +type ExprContext interface { + BuildContext + AggFuncBuildContext +} diff --git a/pkg/expression/contextimpl/BUILD.bazel b/pkg/expression/contextimpl/BUILD.bazel index 484b498ada06f..aafacc17c4f12 100644 --- a/pkg/expression/contextimpl/BUILD.bazel +++ b/pkg/expression/contextimpl/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "//pkg/util", "//pkg/util/intest", "//pkg/util/logutil", + "//pkg/util/mathutil", "@com_github_tikv_client_go_v2//oracle", "@org_uber_go_zap//:zap", ], @@ -31,7 +32,7 @@ go_test( timeout = "short", srcs = ["sessionctx_test.go"], flaky = True, - shard_count = 4, + shard_count = 5, deps = [ ":contextimpl", "//pkg/errctx", @@ -42,6 +43,7 @@ go_test( "//pkg/privilege", "//pkg/sessionctx/stmtctx", "//pkg/types", + "//pkg/util/mathutil", "//pkg/util/mock", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//mock", diff --git a/pkg/expression/contextimpl/sessionctx.go b/pkg/expression/contextimpl/sessionctx.go index 0b3fc9ab44e1d..5d792f2c9b983 100644 --- a/pkg/expression/contextimpl/sessionctx.go +++ b/pkg/expression/contextimpl/sessionctx.go @@ -34,25 +34,29 @@ import ( "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/mathutil" "github.com/tikv/client-go/v2/oracle" "go.uber.org/zap" ) -// sessionctx.Context + *ExprCtxExtendedImpl should implement `expression.BuildContext` +// sessionctx.Context + *ExprCtxExtendedImpl should implement +// `expression.BuildContext` and `expression.AggFuncBuildContext` // Only used to assert `ExprCtxExtendedImpl` should implement all methods not in `sessionctx.Context` -var _ exprctx.BuildContext = struct { +var _ exprctx.ExprContext = struct { sessionctx.Context *ExprCtxExtendedImpl }{} // ExprCtxExtendedImpl extends the sessionctx.Context to implement `expression.BuildContext` type ExprCtxExtendedImpl struct { + sctx sessionctx.Context *SessionEvalContext } // NewExprExtendedImpl creates a new ExprCtxExtendedImpl. func NewExprExtendedImpl(sctx sessionctx.Context) *ExprCtxExtendedImpl { return &ExprCtxExtendedImpl{ + sctx: sctx, SessionEvalContext: NewSessionEvalContext(sctx), } } @@ -62,6 +66,60 @@ func (ctx *ExprCtxExtendedImpl) GetEvalCtx() exprctx.EvalContext { return ctx.SessionEvalContext } +// GetCharsetInfo gets charset and collation for current context. +func (ctx *ExprCtxExtendedImpl) GetCharsetInfo() (string, string) { + return ctx.sctx.GetSessionVars().GetCharsetInfo() +} + +// GetDefaultCollationForUTF8MB4 returns the default collation of UTF8MB4. +func (ctx *ExprCtxExtendedImpl) GetDefaultCollationForUTF8MB4() string { + return ctx.sctx.GetSessionVars().DefaultCollationForUTF8MB4 +} + +// GetBlockEncryptionMode returns the variable block_encryption_mode +func (ctx *ExprCtxExtendedImpl) GetBlockEncryptionMode() string { + blockMode, _ := ctx.sctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + return blockMode +} + +// GetSysdateIsNow returns a bool to determine whether Sysdate is an alias of Now function. +// It is the value of variable `tidb_sysdate_is_now`. +func (ctx *ExprCtxExtendedImpl) GetSysdateIsNow() bool { + return ctx.sctx.GetSessionVars().SysdateIsNow +} + +// GetNoopFuncsMode returns the noop function mode: OFF/ON/WARN values as 0/1/2. +func (ctx *ExprCtxExtendedImpl) GetNoopFuncsMode() int { + return ctx.sctx.GetSessionVars().NoopFuncsMode +} + +// Rng is used to generate random values. +func (ctx *ExprCtxExtendedImpl) Rng() *mathutil.MysqlRng { + return ctx.sctx.GetSessionVars().Rng +} + +// IsUseCache indicates whether to cache the build expression in plan cache. +// If SetSkipPlanCache is invoked, it should return false. +func (ctx *ExprCtxExtendedImpl) IsUseCache() bool { + return ctx.sctx.GetSessionVars().StmtCtx.UseCache +} + +// SetSkipPlanCache sets to skip the plan cache and records the reason. +func (ctx *ExprCtxExtendedImpl) SetSkipPlanCache(reason error) { + ctx.sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) +} + +// GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision. +// see https://dev.mysql.com/doc/refman/8.0/en/window-function-optimization.html for more details. +func (ctx *ExprCtxExtendedImpl) GetWindowingUseHighPrecision() bool { + return ctx.sctx.GetSessionVars().WindowingUseHighPrecision +} + +// GetGroupConcatMaxLen returns the value of the 'group_concat_max_len' system variable. +func (ctx *ExprCtxExtendedImpl) GetGroupConcatMaxLen() uint64 { + return ctx.sctx.GetSessionVars().GroupConcatMaxLen +} + // SessionEvalContext implements the `expression.EvalContext` interface to provide evaluation context in session. type SessionEvalContext struct { sctx sessionctx.Context @@ -92,6 +150,11 @@ func (ctx *SessionEvalContext) setOptionalProp(prop exprctx.OptionalEvalPropProv ctx.props.Add(prop) } +// Sctx returns the innert session context +func (ctx *SessionEvalContext) Sctx() sessionctx.Context { + return ctx.sctx +} + // CtxID returns the context id. func (ctx *SessionEvalContext) CtxID() uint64 { return ctx.sctx.GetSessionVars().StmtCtx.CtxID() diff --git a/pkg/expression/contextimpl/sessionctx_test.go b/pkg/expression/contextimpl/sessionctx_test.go index f12df1c751055..2b694225f99c4 100644 --- a/pkg/expression/contextimpl/sessionctx_test.go +++ b/pkg/expression/contextimpl/sessionctx_test.go @@ -15,6 +15,7 @@ package contextimpl_test import ( + "github.com/pingcap/tidb/pkg/util/mathutil" "sync/atomic" "testing" "time" @@ -39,7 +40,7 @@ func TestSessionEvalContextBasic(t *testing.T) { ctx := mock.NewContext() vars := ctx.GetSessionVars() sc := vars.StmtCtx - impl := contextimpl.NewExprExtendedImpl(ctx) + impl := contextimpl.NewSessionEvalContext(ctx) require.True(t, impl.GetOptionalPropSet().IsFull()) // should contain all the optional properties @@ -97,7 +98,7 @@ func TestSessionEvalContextCurrentTime(t *testing.T) { ctx := mock.NewContext() vars := ctx.GetSessionVars() sc := vars.StmtCtx - impl := contextimpl.NewExprExtendedImpl(ctx) + impl := contextimpl.NewSessionEvalContext(ctx) var now atomic.Pointer[time.Time] sc.SetStaleTSOProvider(func() (uint64, error) { @@ -164,7 +165,7 @@ func (m *mockPrivManager) RequestDynamicVerification( func TestSessionEvalContextPrivilegeCheck(t *testing.T) { ctx := mock.NewContext() - impl := contextimpl.NewExprExtendedImpl(ctx) + impl := contextimpl.NewSessionEvalContext(ctx) activeRoles := []*auth.RoleIdentity{ {Username: "role1", Hostname: "host1"}, {Username: "role2", Hostname: "host2"}, @@ -201,7 +202,7 @@ func TestSessionEvalContextPrivilegeCheck(t *testing.T) { func getProvider[T context.OptionalEvalPropProvider]( t *testing.T, - impl *contextimpl.ExprCtxExtendedImpl, + impl *contextimpl.SessionEvalContext, key context.OptionalEvalPropKey, ) T { val, ok := impl.GetOptionalPropProvider(key) @@ -214,7 +215,7 @@ func getProvider[T context.OptionalEvalPropProvider]( func TestSessionEvalContextOptProps(t *testing.T) { ctx := mock.NewContext() - impl := contextimpl.NewExprExtendedImpl(ctx) + impl := contextimpl.NewSessionEvalContext(ctx) // test for OptPropCurrentUser ctx.GetSessionVars().User = &auth.UserIdentity{Username: "user1", Hostname: "host1"} @@ -242,3 +243,39 @@ func TestSessionEvalContextOptProps(t *testing.T) { ctx.SetIsDDLOwner(true) require.True(t, ddlInfoProvider()) } + +func TestSessionBuildContext(t *testing.T) { + ctx := mock.NewContext() + impl := contextimpl.NewExprExtendedImpl(ctx) + evalCtx, ok := impl.GetEvalCtx().(*contextimpl.SessionEvalContext) + require.True(t, ok) + require.Same(t, evalCtx, impl.SessionEvalContext) + require.True(t, evalCtx.GetOptionalPropSet().IsFull()) + require.Same(t, ctx, evalCtx.Sctx()) + + vars := ctx.GetSessionVars() + err := vars.SetSystemVar("character_set_connection", "gbk") + require.NoError(t, err) + err = vars.SetSystemVar("collation_connection", "gbk_chinese_ci") + require.NoError(t, err) + vars.DefaultCollationForUTF8MB4 = "utf8mb4_0900_ai_ci" + + charset, collate := impl.GetCharsetInfo() + require.Equal(t, "gbk", charset) + require.Equal(t, "gbk_chinese_ci", collate) + require.Equal(t, "utf8mb4_0900_ai_ci", impl.GetDefaultCollationForUTF8MB4()) + + vars.SysdateIsNow = true + require.True(t, impl.GetSysdateIsNow()) + + vars.NoopFuncsMode = 2 + require.Equal(t, 2, impl.GetNoopFuncsMode()) + + vars.Rng = mathutil.NewWithSeed(123) + require.Same(t, vars.Rng, impl.Rng()) + + vars.StmtCtx.UseCache = true + require.True(t, impl.IsUseCache()) + impl.SetSkipPlanCache(errors.New("mockReason")) + require.False(t, impl.IsUseCache()) +} diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index 50a0a97efe360..99d942891b757 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -1113,7 +1113,7 @@ func PBToExprs(ctx BuildContext, pbExprs []*tipb.Expr, fieldTps []*types.FieldTy // PBToExpr converts pb structure to expression. func PBToExpr(ctx BuildContext, expr *tipb.Expr, tps []*types.FieldType) (Expression, error) { - sc := ctx.GetSessionVars().StmtCtx + evalCtx := ctx.GetEvalCtx() switch expr.Tp { case tipb.ExprType_ColumnRef: _, offset, err := codec.DecodeInt(expr.Val) @@ -1142,7 +1142,7 @@ func PBToExpr(ctx BuildContext, expr *tipb.Expr, tps []*types.FieldType) (Expres case tipb.ExprType_MysqlDuration: return convertDuration(expr.Val) case tipb.ExprType_MysqlTime: - return convertTime(expr.Val, expr.FieldType, sc.TimeZone()) + return convertTime(expr.Val, expr.FieldType, evalCtx.Location()) case tipb.ExprType_MysqlJson: return convertJSON(expr.Val) case tipb.ExprType_MysqlEnum: diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index a26cdc9398b50..b9a08f3c264fe 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -91,7 +91,7 @@ func datumsToConstants(datums []types.Datum) []Expression { func primitiveValsToConstants(ctx BuildContext, args []any) []Expression { cons := datumsToConstants(types.MakeDatums(args...)) - char, col := ctx.GetSessionVars().GetCharsetInfo() + char, col := ctx.GetCharsetInfo() for i, arg := range args { types.DefaultTypeForValue(arg, cons[i].GetType(), char, col) } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 63783e3ac91f2..ca049f70c4747 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -863,7 +863,7 @@ func SplitDNFItems(onExpr Expression) []Expression { // If the Expression is a non-constant value, it means the result is unknown. func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Expression { if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) { - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check")) + ctx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check")) } if ctx.GetSessionVars().StmtCtx.InNullRejectCheck { expr, _ = evaluateExprWithNullInNullRejectCheck(ctx, schema, expr) diff --git a/pkg/expression/extension.go b/pkg/expression/extension.go index 762192baa3170..09b8a087f12cb 100644 --- a/pkg/expression/extension.go +++ b/pkg/expression/extension.go @@ -111,7 +111,7 @@ func (c *extensionFuncClass) getFunction(ctx BuildContext, args []Expression) (b // Though currently, `getFunction` does not require too much information that makes it safe to be cached, // we still skip the plan cache for extension functions because there are no strong requirements to do it. // Skipping the plan cache can make the behavior simple. - ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("extension function should not be cached")) + ctx.SetSkipPlanCache(errors.NewNoStackError("extension function should not be cached")) bf.tp.SetFlen(c.flen) sig := &extensionFuncSig{baseBuiltinFunc: bf, FunctionDef: c.funcDef} return sig, nil diff --git a/pkg/expression/function_traits.go b/pkg/expression/function_traits.go index 83d1bcffeb3b9..dc9cbca9a48e3 100644 --- a/pkg/expression/function_traits.go +++ b/pkg/expression/function_traits.go @@ -141,7 +141,7 @@ func IsDeferredFunctions(ctx BuildContext, fn string) bool { if ok { return ok } - if fn == ast.Sysdate && ctx.GetSessionVars().SysdateIsNow { + if fn == ast.Sysdate && ctx.GetSysdateIsNow() { return true } return ok diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 3c91f5f7274fb..da2a560676964 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -73,7 +73,7 @@ func getTimeCurrentTimeStamp(ctx BuildContext, tp byte, fsp int) (t types.Time, } value.SetCoreTime(types.FromGoTime(defaultTime.Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond))) if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { - err = value.ConvertTimeZone(time.Local, ctx.GetSessionVars().Location()) + err = value.ConvertTimeZone(time.Local, ctx.GetEvalCtx().Location()) if err != nil { return value, err } @@ -84,7 +84,7 @@ func getTimeCurrentTimeStamp(ctx BuildContext, tp byte, fsp int) (t types.Time, // GetTimeValue gets the time value with type tp. func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Location) (d types.Datum, err error) { var value types.Time - tc := ctx.GetSessionVars().StmtCtx.TypeCtx() + tc := ctx.GetEvalCtx().TypeCtx() if explicitTz != nil { tc = tc.WithLocation(explicitTz) } diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 528d9db108b8a..fcef118dc5002 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -205,7 +205,7 @@ func newFunctionImpl(ctx BuildContext, fold int, funcName string, retType *types case InternalFuncToBinary: return BuildToBinaryFunction(ctx, args[0]), nil case ast.Sysdate: - if ctx.GetSessionVars().SysdateIsNow { + if ctx.GetSysdateIsNow() { funcName = ast.Now } } @@ -224,7 +224,7 @@ func newFunctionImpl(ctx BuildContext, fold int, funcName string, retType *types } return nil, ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", db+"."+funcName) } - noopFuncsMode := ctx.GetSessionVars().NoopFuncsMode + noopFuncsMode := ctx.GetNoopFuncsMode() if noopFuncsMode != variable.OnInt { if _, ok := noopFuncs[funcName]; ok { err := ErrFunctionsNoopImpl.FastGenByArgs(funcName) @@ -232,7 +232,7 @@ func newFunctionImpl(ctx BuildContext, fold int, funcName string, retType *types return nil, errors.Trace(err) } // NoopFuncsMode is Warn, append an error - ctx.GetSessionVars().StmtCtx.AppendWarning(err) + ctx.GetEvalCtx().AppendWarning(err) } } funcArgs := make([]Expression, len(args)) @@ -264,12 +264,12 @@ func newFunctionImpl(ctx BuildContext, fold int, funcName string, retType *types return FoldConstant(ctx, sf), nil } else if fold == -1 { // try to fold constants, and return the original function if errors/warnings occur - sc := ctx.GetSessionVars().StmtCtx - beforeWarns := sc.WarningCount() + evalCtx := ctx.GetEvalCtx() + beforeWarns := evalCtx.WarningCount() newSf := FoldConstant(ctx, sf) - afterWarns := sc.WarningCount() + afterWarns := evalCtx.WarningCount() if afterWarns > beforeWarns { - sc.TruncateWarnings(int(beforeWarns)) + evalCtx.TruncateWarnings(beforeWarns) return sf, nil } return newSf, nil diff --git a/pkg/expression/simple_rewriter.go b/pkg/expression/simple_rewriter.go index 7f733babada3b..1918c9731f3ad 100644 --- a/pkg/expression/simple_rewriter.go +++ b/pkg/expression/simple_rewriter.go @@ -51,7 +51,7 @@ func ParseSimpleExpr(ctx BuildContext, exprStr string, opts ...BuildOption) (Exp stmts, warns, err = parser.New().ParseSQL(exprStr) } for _, warn := range warns { - ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + ctx.GetEvalCtx().AppendWarning(util.SyntaxWarn(warn)) } if err != nil { diff --git a/pkg/expression/util.go b/pkg/expression/util.go index 554c87302af9d..0c7952081e5a9 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -1465,7 +1465,7 @@ func ContainCorrelatedColumn(exprs []Expression) bool { // TODO: Do more careful check here. func MaybeOverOptimized4PlanCache(ctx BuildContext, exprs []Expression) bool { // If we do not enable plan cache, all the optimization can work correctly. - if !ctx.GetSessionVars().StmtCtx.UseCache { + if !ctx.IsUseCache() { return false } return containMutableConst(ctx.GetEvalCtx(), exprs) diff --git a/pkg/lightning/backend/kv/session.go b/pkg/lightning/backend/kv/session.go index b48d0ca1c3ba6..1f4d35a965503 100644 --- a/pkg/lightning/backend/kv/session.go +++ b/pkg/lightning/backend/kv/session.go @@ -405,7 +405,7 @@ func (se *Session) GetPlanCtx() planctx.PlanContext { } // GetExprCtx returns the expression context of the session. -func (se *Session) GetExprCtx() exprctx.BuildContext { +func (se *Session) GetExprCtx() exprctx.ExprContext { return se.exprCtx } diff --git a/pkg/planner/context/context.go b/pkg/planner/context/context.go index ff17f7e9f34c4..3acd970769726 100644 --- a/pkg/planner/context/context.go +++ b/pkg/planner/context/context.go @@ -35,7 +35,7 @@ type PlanContext interface { // GetRestrictedSQLExecutor gets the RestrictedSQLExecutor. GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor // GetExprCtx gets the expression context. - GetExprCtx() exprctx.BuildContext + GetExprCtx() exprctx.ExprContext // GetStore returns the store of session. GetStore() kv.Storage // GetSessionVars gets the session variables. diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index b25481af2fd48..6b179f8788035 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -699,7 +699,7 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, planCtx noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0 if noDecorrelate && len(extractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())) == 0 { - er.sctx.GetSessionVars().StmtCtx.SetHintWarning( + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() is inapplicable because there are no correlated columns.") noDecorrelate = false } @@ -804,7 +804,7 @@ func (er *expressionRewriter) handleOtherComparableSubq(planCtx *exprRewriterPla // Create a column and append it to the schema of that aggregation. colMaxOrMin := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID(), RetType: funcMaxOrMin.RetTp, } colMaxOrMin.SetCoercibility(rexpr.Coercibility()) @@ -829,8 +829,9 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, er.err = err return } + sessVars := planCtx.builder.ctx.GetSessionVars() colSum := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sessVars.AllocPlanColumnID(), RetType: funcSum.RetTp, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) @@ -844,7 +845,7 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, return } colCount := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sessVars.AllocPlanColumnID(), RetType: funcCount.RetTp, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) @@ -891,7 +892,7 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, proj.SetSchema(expression.NewSchema(joinSchema.Clone().Columns[:outerSchemaLen]...)) proj.Exprs = append(proj.Exprs, cond) proj.schema.Append(&expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sessVars.AllocPlanColumnID(), RetType: cond.GetType(), }) proj.names = append(proj.names, types.EmptyName) @@ -927,12 +928,12 @@ func (er *expressionRewriter) handleNEAny(planCtx *exprRewriterPlanCtx, lexpr, r } plan4Agg.SetChildren(np) maxResultCol := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), RetType: maxFunc.RetTp, } maxResultCol.SetCoercibility(rexpr.Coercibility()) count := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, } plan4Agg.names = append(plan4Agg.names, types.EmptyName, types.EmptyName) @@ -980,13 +981,13 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r firstRowFunc.RetTp = newRetTp firstRowResultCol := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), RetType: firstRowFunc.RetTp, } firstRowResultCol.SetCoercibility(rexpr.Coercibility()) plan4Agg.names = append(plan4Agg.names, types.EmptyName) count := &expression.Column{ - UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, } plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count)) @@ -1015,13 +1016,13 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx * noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0 if noDecorrelate && len(extractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())) == 0 { - er.sctx.GetSessionVars().StmtCtx.SetHintWarning( + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() is inapplicable because there are no correlated columns.") noDecorrelate = false } semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0 if semiJoinRewrite && noDecorrelate { - er.sctx.GetSessionVars().StmtCtx.SetHintWarning( + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() and SEMI_JOIN_REWRITE() are in conflict. Both will be ineffective.") noDecorrelate = false semiJoinRewrite = false @@ -1035,10 +1036,10 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx * er.ctxStackAppend(planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1], planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1]) } else { // We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily. - nthPlanBackup := er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan - er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1 + nthPlanBackup := b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan + b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1 physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, b.optFlag, np) - er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup + b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup if err != nil { er.err = err return v, true @@ -1191,7 +1192,7 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exp noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0 corCols := extractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema()) if len(corCols) == 0 && noDecorrelate { - er.sctx.GetSessionVars().StmtCtx.SetHintWarning( + planCtx.builder.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() is inapplicable because there are no correlated columns.") noDecorrelate = false } @@ -1200,7 +1201,7 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exp // and has no correlated column from the current level plan(if the correlated column is from upper level, // we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node), // and don't need to append a scalar value, we can rewrite it to inner join. - if er.sctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag { + if planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag { // We need to try to eliminate the agg and the projection produced by this operation. planCtx.builder.optFlag |= flagEliminateAgg planCtx.builder.optFlag |= flagEliminateProjection @@ -1252,7 +1253,7 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, planCtx noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0 if noDecorrelate && len(extractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())) == 0 { - er.sctx.GetSessionVars().StmtCtx.SetHintWarning( + planCtx.builder.ctx.GetSessionVars().StmtCtx.SetHintWarning( "NO_DECORRELATE() is inapplicable because there are no correlated columns.") noDecorrelate = false } @@ -1276,10 +1277,10 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, planCtx return v, true } // We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily. - nthPlanBackup := er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan - er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1 + nthPlanBackup := planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan + planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1 physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, planCtx.builder.optFlag, np) - er.sctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup + planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup if err != nil { er.err = err return v, true @@ -1376,7 +1377,7 @@ func initConstantRepertoire(c *expression.Constant) { func (er *expressionRewriter) adjustUTF8MB4Collation(tp *types.FieldType) { if tp.GetFlag()&mysql.UnderScoreCharsetFlag > 0 && charset.CharsetUTF8MB4 == tp.GetCharset() { - tp.SetCollate(er.sctx.GetSessionVars().DefaultCollationForUTF8MB4) + tp.SetCollate(er.sctx.GetDefaultCollationForUTF8MB4()) } } @@ -1634,7 +1635,7 @@ func (*expressionRewriter) checkTimePrecision(ft *types.FieldType) error { } func (er *expressionRewriter) useCache() bool { - return er.sctx.GetSessionVars().StmtCtx.UseCache + return er.sctx.IsUseCache() } func (er *expressionRewriter) rewriteVariable(planCtx *exprRewriterPlanCtx, v *ast.VariableExpr) { @@ -1874,7 +1875,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field if c.GetType().EvalType() == types.ETInt { continue // no need to refine it } - er.sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackErrorf("'%v' may be converted to INT", c.String())) + er.sctx.SetSkipPlanCache(errors.NewNoStackErrorf("'%v' may be converted to INT", c.String())) if err := expression.RemoveMutableConst(er.sctx, []expression.Expression{c}); err != nil { er.err = err return @@ -2041,7 +2042,7 @@ func (er *expressionRewriter) patternLikeOrIlikeToExpression(v *ast.PatternLikeO return } - char, col := er.sctx.GetSessionVars().GetCharsetInfo() + char, col := er.sctx.GetCharsetInfo() var function expression.Expression fieldType := &types.FieldType{} isPatternExactMatch := false diff --git a/pkg/planner/core/indexmerge_path.go b/pkg/planner/core/indexmerge_path.go index 6f9a0a16c56a3..a7995c8eaef42 100644 --- a/pkg/planner/core/indexmerge_path.go +++ b/pkg/planner/core/indexmerge_path.go @@ -1598,7 +1598,7 @@ func jsonArrayExpr2Exprs( ) ([]expression.Expression, bool) { if checkForSkipPlanCache && expression.MaybeOverOptimized4PlanCache(sctx, []expression.Expression{jsonArrayExpr}) { // skip plan cache and try to generate the best plan in this case. - sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError(jsonFuncName + " function with immutable parameters can affect index selection")) + sctx.SetSkipPlanCache(errors.NewNoStackError(jsonFuncName + " function with immutable parameters can affect index selection")) } if !expression.IsImmutableFunc(jsonArrayExpr) || jsonArrayExpr.GetType().EvalType() != types.ETJson { return nil, false diff --git a/pkg/session/session.go b/pkg/session/session.go index 2988823d40ccc..04d2a88c3e849 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2627,7 +2627,7 @@ func (s *session) GetPlanCtx() planctx.PlanContext { } // GetExprCtx returns the expression context of the session. -func (s *session) GetExprCtx() exprctx.BuildContext { +func (s *session) GetExprCtx() exprctx.ExprContext { return s.exprctx } diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go index 4888177c369f7..b1e8e0334fea3 100644 --- a/pkg/sessionctx/context.go +++ b/pkg/sessionctx/context.go @@ -108,7 +108,7 @@ type Context interface { GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor // GetExprCtx returns the expression context of the session. - GetExprCtx() exprctx.BuildContext + GetExprCtx() exprctx.ExprContext // GetTableCtx returns the table.MutateContext GetTableCtx() tbctx.MutateContext diff --git a/pkg/table/column.go b/pkg/table/column.go index 8dd1b6e9f08d7..2878bc56d9129 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -635,7 +635,7 @@ func getColDefaultValue(ctx expression.BuildContext, col *model.ColumnInfo, defa // If the column's default value is not ZeroDatetimeStr or CurrentTimestamp, convert the default value to the current session time zone. if needChangeTimeZone { t := value.GetMysqlTime() - err = t.ConvertTimeZone(explicitTz, ctx.GetSessionVars().Location()) + err = t.ConvertTimeZone(explicitTz, ctx.GetEvalCtx().Location()) if err != nil { return value, err } @@ -661,19 +661,18 @@ func getColDefaultValueFromNil(ctx expression.BuildContext, col *model.ColumnInf // Auto increment column doesn't have default value and we should not return error. return GetZeroValue(col), nil } - vars := ctx.GetSessionVars() - sc := vars.StmtCtx + evalCtx := ctx.GetEvalCtx() var strictSQLMode bool if args != nil { strictSQLMode = args.StrictSQLMode } else { - strictSQLMode = vars.SQLMode.HasStrictMode() + strictSQLMode = evalCtx.SQLMode().HasStrictMode() } if !strictSQLMode { - sc.AppendWarning(ErrNoDefaultValue.FastGenByArgs(col.Name)) + evalCtx.AppendWarning(ErrNoDefaultValue.FastGenByArgs(col.Name)) return GetZeroValue(col), nil } - ec := sc.ErrCtx() + ec := evalCtx.ErrCtx() var err error if mysql.HasNoDefaultValueFlag(col.GetFlag()) { err = ErrNoDefaultValue.FastGenByArgs(col.Name) @@ -754,6 +753,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd virCols := chunk.NewChunkWithCapacity(virtualRetTypes, req.Capacity()) iter := chunk.NewIterator4Chunk(req) evalCtx := ectx.GetEvalCtx() + tc := evalCtx.TypeCtx() for i, idx := range virtualColumnIndex { for row := iter.Begin(); row != iter.End(); row = iter.Next() { datum, err := expCols[idx].EvalVirtualColumn(evalCtx, row) @@ -768,7 +768,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd } // Clip to zero if get negative value after cast to unsigned. - if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && ectx.GetSessionVars().StmtCtx.TypeFlags().AllowNegativeToUnsigned() { + if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && tc.Flags().AllowNegativeToUnsigned() { switch datum.Kind() { case types.KindInt64: if datum.GetInt64() < 0 { diff --git a/pkg/table/context/table.go b/pkg/table/context/table.go index a250f94aeac4d..e58bcc36b7a19 100644 --- a/pkg/table/context/table.go +++ b/pkg/table/context/table.go @@ -32,7 +32,7 @@ var _ AllocatorContext = MutateContext(nil) type MutateContext interface { AllocatorContext // GetExprCtx returns the context to build or evaluate expressions - GetExprCtx() exprctx.BuildContext + GetExprCtx() exprctx.ExprContext // Value returns the value associated with this context for key. Value(key fmt.Stringer) any // GetSessionVars returns the session variables. diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index f6402f1198ed5..165fee8ab4971 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -29,11 +29,11 @@ var _ context.AllocatorContext = &TableContextImpl{} // TableContextImpl is used to provide context for table operations. type TableContextImpl struct { sessionctx.Context - exprCtx exprctx.BuildContext + exprCtx exprctx.ExprContext } // NewTableContextImpl creates a new TableContextImpl. -func NewTableContextImpl(sctx sessionctx.Context, exprCtx exprctx.BuildContext) *TableContextImpl { +func NewTableContextImpl(sctx sessionctx.Context, exprCtx exprctx.ExprContext) *TableContextImpl { return &TableContextImpl{Context: sctx, exprCtx: exprCtx} } @@ -44,7 +44,7 @@ func (ctx *TableContextImpl) TxnRecordTempTable(tbl *model.TableInfo) tableutil. } // GetExprCtx returns the ExprContext -func (ctx *TableContextImpl) GetExprCtx() exprctx.BuildContext { +func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext { return ctx.exprCtx } diff --git a/pkg/table/tables/partition.go b/pkg/table/tables/partition.go index c2699780e4fa3..e0d58c946729c 100644 --- a/pkg/table/tables/partition.go +++ b/pkg/table/tables/partition.go @@ -1092,8 +1092,8 @@ func (lp *ForListColumnPruning) genConstExprKey(ctx expression.BuildContext, exp if err != nil { return nil, errors.Trace(err) } - sc := ctx.GetSessionVars().StmtCtx - tc, ec := sc.TypeCtx(), sc.ErrCtx() + evalCtx := ctx.GetEvalCtx() + tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() key, err := lp.genKey(tc, ec, v) if err != nil { return nil, errors.Trace(err) diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index 9fb16cc7fdc2f..38b564b0397a4 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -2312,14 +2312,15 @@ func SetPBColumnsDefaultValue(ctx expression.BuildContext, pbColumns []*tipb.Col continue } - sessVars := ctx.GetSessionVars() + evalCtx := ctx.GetEvalCtx() d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, c) if err != nil { return err } - pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(sessVars.StmtCtx.TimeZone(), nil, d) - err = sessVars.StmtCtx.HandleError(err) + pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(evalCtx.Location(), nil, d) + ec := evalCtx.ErrCtx() + err = ec.HandleError(err) if err != nil { return err } diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go index bcdbe98f66dd5..cf11298c847af 100644 --- a/pkg/util/mock/context.go +++ b/pkg/util/mock/context.go @@ -239,7 +239,7 @@ func (c *Context) GetPlanCtx() planctx.PlanContext { } // GetExprCtx returns the expression context of the session. -func (c *Context) GetExprCtx() exprctx.BuildContext { +func (c *Context) GetExprCtx() exprctx.ExprContext { return c }