From fab13afa2bc3671ade20522e78ae5d3610fe2e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 9 Apr 2024 17:00:34 +0800 Subject: [PATCH] expression: Move more methods from `SessionVars` to `BuildContext` (#52440) ref pingcap/tidb#52366 --- pkg/expression/builtin_cast.go | 14 ++++----- pkg/expression/constant_fold.go | 3 +- pkg/expression/constant_test.go | 2 +- pkg/expression/context/context.go | 15 ++++++---- pkg/expression/contextimpl/sessionctx.go | 28 ++++++++++++++++++ pkg/expression/contextimpl/sessionctx_test.go | 29 +++++++++++++++++++ pkg/expression/expression.go | 4 +-- pkg/planner/core/rule_predicate_push_down.go | 8 ++--- pkg/sessionctx/stmtctx/stmtctx.go | 1 - 9 files changed, 82 insertions(+), 22 deletions(-) diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index b623a525cfc5c..cfdb35bbfc23b 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -122,7 +122,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression if err != nil { return nil, err } - bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil) + bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) { sig = &builtinCastIntAsIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastIntAsInt) @@ -171,7 +171,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio if err != nil { return nil, err } - bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil) + bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) if IsBinaryLiteral(args[0]) { sig = &builtinCastRealAsRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal) @@ -226,7 +226,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres if err != nil { return nil, err } - bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil) + bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) if IsBinaryLiteral(args[0]) { sig = &builtinCastDecimalAsDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsDecimal) @@ -2052,10 +2052,10 @@ func CanImplicitEvalReal(expr Expression) bool { // BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union // Expression. func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) { - ctx.SetValue(inUnionCastContext, struct{}{}) - defer func() { - ctx.SetValue(inUnionCastContext, nil) - }() + if !ctx.IsInUnionCast() { + ctx.SetInUnionCast(true) + defer ctx.SetInUnionCast(false) + } return BuildCastFunction(ctx, expr, tp) } diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index 55931403864ab..7f827c4607377 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -173,7 +173,6 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) { } args := x.GetArgs() - sc := ctx.GetSessionVars().StmtCtx argIsConst := make([]bool, len(args)) hasNullArg := false allConstArg := true @@ -194,7 +193,7 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) { // // NullEQ and ConcatWS are excluded, because they could have different value when the non-constant value is // 1 or NULL. For example, concat_ws(NULL, NULL) gives NULL, but concat_ws(1, NULL) gives '' - if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS { + if !hasNullArg || !ctx.IsInNullRejectCheck() || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS { return expr, isDeferredConst } constArgs := make([]Expression, len(args)) diff --git a/pkg/expression/constant_test.go b/pkg/expression/constant_test.go index 652779fe9bb46..be622ed7caf9e 100644 --- a/pkg/expression/constant_test.go +++ b/pkg/expression/constant_test.go @@ -236,7 +236,7 @@ func TestConstantFolding(t *testing.T) { { condition: func(ctx BuildContext) Expression { expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull()) - ctx.GetSessionVars().StmtCtx.InNullRejectCheck = true + ctx.SetInNullRejectCheck(true) return expr }, result: "concat_ws(cast(Column#0, var_string(20)), )", diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 73a817fdf476a..679ab7df41674 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ -15,7 +15,6 @@ package context import ( - "fmt" "time" "github.com/pingcap/tidb/pkg/errctx" @@ -86,12 +85,18 @@ type BuildContext interface { IsUseCache() bool // SetSkipPlanCache sets to skip the plan cache and records the reason. SetSkipPlanCache(reason error) + // AllocPlanColumnID allocates column id for plan. + AllocPlanColumnID() int64 + // SetInNullRejectCheck sets the flag to indicate whether the expression is in null reject check. + SetInNullRejectCheck(in bool) + // IsInNullRejectCheck returns the flag to indicate whether the expression is in null reject check. + IsInNullRejectCheck() bool + // SetInUnionCast sets the flag to indicate whether the expression is in union cast. + SetInUnionCast(in bool) + // IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero. + IsInUnionCast() bool // GetSessionVars gets the session variables. GetSessionVars() *variable.SessionVars - // Value returns the value associated with this context for key. - Value(key fmt.Stringer) any - // SetValue saves a value associated with this context for key. - SetValue(key fmt.Stringer, value any) } // ExprContext contains full context for expression building and evaluating. diff --git a/pkg/expression/contextimpl/sessionctx.go b/pkg/expression/contextimpl/sessionctx.go index 5d792f2c9b983..b6eafaa50c09f 100644 --- a/pkg/expression/contextimpl/sessionctx.go +++ b/pkg/expression/contextimpl/sessionctx.go @@ -17,6 +17,7 @@ package contextimpl import ( "context" "math" + "sync/atomic" "time" "github.com/pingcap/tidb/pkg/errctx" @@ -51,6 +52,8 @@ var _ exprctx.ExprContext = struct { type ExprCtxExtendedImpl struct { sctx sessionctx.Context *SessionEvalContext + inNullRejectCheck atomic.Bool + inUnionCast atomic.Bool } // NewExprExtendedImpl creates a new ExprCtxExtendedImpl. @@ -109,6 +112,31 @@ func (ctx *ExprCtxExtendedImpl) SetSkipPlanCache(reason error) { ctx.sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason) } +// AllocPlanColumnID allocates column id for plan. +func (ctx *ExprCtxExtendedImpl) AllocPlanColumnID() int64 { + return ctx.sctx.GetSessionVars().AllocPlanColumnID() +} + +// SetInNullRejectCheck sets whether the expression is in null reject check. +func (ctx *ExprCtxExtendedImpl) SetInNullRejectCheck(in bool) { + ctx.inNullRejectCheck.Store(in) +} + +// IsInNullRejectCheck returns whether the expression is in null reject check. +func (ctx *ExprCtxExtendedImpl) IsInNullRejectCheck() bool { + return ctx.inNullRejectCheck.Load() +} + +// SetInUnionCast sets the flag to indicate whether the expression is in union cast. +func (ctx *ExprCtxExtendedImpl) SetInUnionCast(in bool) { + ctx.inUnionCast.Store(in) +} + +// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero. +func (ctx *ExprCtxExtendedImpl) IsInUnionCast() bool { + return ctx.inUnionCast.Load() +} + // GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision. // see https://dev.mysql.com/doc/refman/8.0/en/window-function-optimization.html for more details. func (ctx *ExprCtxExtendedImpl) GetWindowingUseHighPrecision() bool { diff --git a/pkg/expression/contextimpl/sessionctx_test.go b/pkg/expression/contextimpl/sessionctx_test.go index 5e6a10f087db7..6af55175479d3 100644 --- a/pkg/expression/contextimpl/sessionctx_test.go +++ b/pkg/expression/contextimpl/sessionctx_test.go @@ -253,6 +253,7 @@ func TestSessionBuildContext(t *testing.T) { require.True(t, evalCtx.GetOptionalPropSet().IsFull()) require.Same(t, ctx, evalCtx.Sctx()) + // charset and collation vars := ctx.GetSessionVars() err := vars.SetSystemVar("character_set_connection", "gbk") require.NoError(t, err) @@ -265,17 +266,45 @@ func TestSessionBuildContext(t *testing.T) { require.Equal(t, "gbk_chinese_ci", collate) require.Equal(t, "utf8mb4_0900_ai_ci", impl.GetDefaultCollationForUTF8MB4()) + // SysdateIsNow vars.SysdateIsNow = true require.True(t, impl.GetSysdateIsNow()) + // NoopFuncsMode vars.NoopFuncsMode = 2 require.Equal(t, 2, impl.GetNoopFuncsMode()) + // Rng vars.Rng = mathutil.NewWithSeed(123) require.Same(t, vars.Rng, impl.Rng()) + // PlanCache vars.StmtCtx.UseCache = true require.True(t, impl.IsUseCache()) impl.SetSkipPlanCache(errors.New("mockReason")) require.False(t, impl.IsUseCache()) + + // Alloc column id + prevID := vars.PlanColumnID.Load() + colID := impl.AllocPlanColumnID() + require.Equal(t, colID, prevID+1) + colID = impl.AllocPlanColumnID() + require.Equal(t, colID, prevID+2) + vars.AllocPlanColumnID() + colID = impl.AllocPlanColumnID() + require.Equal(t, colID, prevID+4) + + // InNullRejectCheck + require.False(t, impl.IsInNullRejectCheck()) + impl.SetInNullRejectCheck(true) + require.True(t, impl.IsInNullRejectCheck()) + impl.SetInNullRejectCheck(false) + require.False(t, impl.IsInNullRejectCheck()) + + // InUnionCast + require.False(t, impl.IsInUnionCast()) + impl.SetInUnionCast(true) + require.True(t, impl.IsInUnionCast()) + impl.SetInUnionCast(false) + require.False(t, impl.IsInUnionCast()) } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index ca049f70c4747..3dd6e60d3479b 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -865,7 +865,7 @@ func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Exp if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) { ctx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check")) } - if ctx.GetSessionVars().StmtCtx.InNullRejectCheck { + if ctx.IsInNullRejectCheck() { expr, _ = evaluateExprWithNullInNullRejectCheck(ctx, schema, expr) return expr } @@ -1022,7 +1022,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr, newCol := &Column{ RetType: col.FieldType.Clone(), ID: col.ID, - UniqueID: ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: ctx.AllocPlanColumnID(), Index: col.Offset, OrigName: names[i].String(), IsHidden: col.Hidden, diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index 9561108586257..1e4fec21e434e 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -428,10 +428,10 @@ func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression. return false } sc := ctx.GetSessionVars().StmtCtx - sc.InNullRejectCheck = true - defer func() { - sc.InNullRejectCheck = false - }() + if !exprCtx.IsInNullRejectCheck() { + exprCtx.SetInNullRejectCheck(true) + defer exprCtx.SetInNullRejectCheck(false) + } for _, cond := range expression.SplitCNFItems(expr) { if isNullRejectedSpecially(ctx, schema, expr) { return true diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index da6d1d7217a68..e7b6c54ad5e4a 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -198,7 +198,6 @@ type StatementContext struct { ForcePlanCache bool // force the optimizer to use plan cache even if there is risky optimization, see #49736. CacheType PlanCacheType BatchCheck bool - InNullRejectCheck bool IgnoreExplainIDSuffix bool MultiSchemaInfo *model.MultiSchemaInfo // If the select statement was like 'select * from t as of timestamp ...' or in a stale read transaction