Skip to content

Commit

Permalink
expression: Move more methods from SessionVars to BuildContext (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Apr 9, 2024
1 parent 38f665a commit fab13af
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 22 deletions.
14 changes: 7 additions & 7 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) {
sig = &builtinCastIntAsIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsInt)
Expand Down Expand Up @@ -171,7 +171,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if IsBinaryLiteral(args[0]) {
sig = &builtinCastRealAsRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal)
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if IsBinaryLiteral(args[0]) {
sig = &builtinCastDecimalAsDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsDecimal)
Expand Down Expand Up @@ -2052,10 +2052,10 @@ func CanImplicitEvalReal(expr Expression) bool {
// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union
// Expression.
func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) {
ctx.SetValue(inUnionCastContext, struct{}{})
defer func() {
ctx.SetValue(inUnionCastContext, nil)
}()
if !ctx.IsInUnionCast() {
ctx.SetInUnionCast(true)
defer ctx.SetInUnionCast(false)
}
return BuildCastFunction(ctx, expr, tp)
}

Expand Down
3 changes: 1 addition & 2 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) {
}

args := x.GetArgs()
sc := ctx.GetSessionVars().StmtCtx
argIsConst := make([]bool, len(args))
hasNullArg := false
allConstArg := true
Expand All @@ -194,7 +193,7 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) {
//
// NullEQ and ConcatWS are excluded, because they could have different value when the non-constant value is
// 1 or NULL. For example, concat_ws(NULL, NULL) gives NULL, but concat_ws(1, NULL) gives ''
if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS {
if !hasNullArg || !ctx.IsInNullRejectCheck() || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS {
return expr, isDeferredConst
}
constArgs := make([]Expression, len(args))
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func TestConstantFolding(t *testing.T) {
{
condition: func(ctx BuildContext) Expression {
expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull())
ctx.GetSessionVars().StmtCtx.InNullRejectCheck = true
ctx.SetInNullRejectCheck(true)
return expr
},
result: "concat_ws(cast(Column#0, var_string(20)), <nil>)",
Expand Down
15 changes: 10 additions & 5 deletions pkg/expression/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package context

import (
"fmt"
"time"

"github.com/pingcap/tidb/pkg/errctx"
Expand Down Expand Up @@ -86,12 +85,18 @@ type BuildContext interface {
IsUseCache() bool
// SetSkipPlanCache sets to skip the plan cache and records the reason.
SetSkipPlanCache(reason error)
// AllocPlanColumnID allocates column id for plan.
AllocPlanColumnID() int64
// SetInNullRejectCheck sets the flag to indicate whether the expression is in null reject check.
SetInNullRejectCheck(in bool)
// IsInNullRejectCheck returns the flag to indicate whether the expression is in null reject check.
IsInNullRejectCheck() bool
// SetInUnionCast sets the flag to indicate whether the expression is in union cast.
SetInUnionCast(in bool)
// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero.
IsInUnionCast() bool
// GetSessionVars gets the session variables.
GetSessionVars() *variable.SessionVars
// Value returns the value associated with this context for key.
Value(key fmt.Stringer) any
// SetValue saves a value associated with this context for key.
SetValue(key fmt.Stringer, value any)
}

// ExprContext contains full context for expression building and evaluating.
Expand Down
28 changes: 28 additions & 0 deletions pkg/expression/contextimpl/sessionctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package contextimpl
import (
"context"
"math"
"sync/atomic"
"time"

"github.com/pingcap/tidb/pkg/errctx"
Expand Down Expand Up @@ -51,6 +52,8 @@ var _ exprctx.ExprContext = struct {
type ExprCtxExtendedImpl struct {
sctx sessionctx.Context
*SessionEvalContext
inNullRejectCheck atomic.Bool
inUnionCast atomic.Bool
}

// NewExprExtendedImpl creates a new ExprCtxExtendedImpl.
Expand Down Expand Up @@ -109,6 +112,31 @@ func (ctx *ExprCtxExtendedImpl) SetSkipPlanCache(reason error) {
ctx.sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason)
}

// AllocPlanColumnID allocates column id for plan.
func (ctx *ExprCtxExtendedImpl) AllocPlanColumnID() int64 {
return ctx.sctx.GetSessionVars().AllocPlanColumnID()
}

// SetInNullRejectCheck sets whether the expression is in null reject check.
func (ctx *ExprCtxExtendedImpl) SetInNullRejectCheck(in bool) {
ctx.inNullRejectCheck.Store(in)
}

// IsInNullRejectCheck returns whether the expression is in null reject check.
func (ctx *ExprCtxExtendedImpl) IsInNullRejectCheck() bool {
return ctx.inNullRejectCheck.Load()
}

// SetInUnionCast sets the flag to indicate whether the expression is in union cast.
func (ctx *ExprCtxExtendedImpl) SetInUnionCast(in bool) {
ctx.inUnionCast.Store(in)
}

// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero.
func (ctx *ExprCtxExtendedImpl) IsInUnionCast() bool {
return ctx.inUnionCast.Load()
}

// GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision.
// see https://dev.mysql.com/doc/refman/8.0/en/window-function-optimization.html for more details.
func (ctx *ExprCtxExtendedImpl) GetWindowingUseHighPrecision() bool {
Expand Down
29 changes: 29 additions & 0 deletions pkg/expression/contextimpl/sessionctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ func TestSessionBuildContext(t *testing.T) {
require.True(t, evalCtx.GetOptionalPropSet().IsFull())
require.Same(t, ctx, evalCtx.Sctx())

// charset and collation
vars := ctx.GetSessionVars()
err := vars.SetSystemVar("character_set_connection", "gbk")
require.NoError(t, err)
Expand All @@ -265,17 +266,45 @@ func TestSessionBuildContext(t *testing.T) {
require.Equal(t, "gbk_chinese_ci", collate)
require.Equal(t, "utf8mb4_0900_ai_ci", impl.GetDefaultCollationForUTF8MB4())

// SysdateIsNow
vars.SysdateIsNow = true
require.True(t, impl.GetSysdateIsNow())

// NoopFuncsMode
vars.NoopFuncsMode = 2
require.Equal(t, 2, impl.GetNoopFuncsMode())

// Rng
vars.Rng = mathutil.NewWithSeed(123)
require.Same(t, vars.Rng, impl.Rng())

// PlanCache
vars.StmtCtx.UseCache = true
require.True(t, impl.IsUseCache())
impl.SetSkipPlanCache(errors.New("mockReason"))
require.False(t, impl.IsUseCache())

// Alloc column id
prevID := vars.PlanColumnID.Load()
colID := impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+1)
colID = impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+2)
vars.AllocPlanColumnID()
colID = impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+4)

// InNullRejectCheck
require.False(t, impl.IsInNullRejectCheck())
impl.SetInNullRejectCheck(true)
require.True(t, impl.IsInNullRejectCheck())
impl.SetInNullRejectCheck(false)
require.False(t, impl.IsInNullRejectCheck())

// InUnionCast
require.False(t, impl.IsInUnionCast())
impl.SetInUnionCast(true)
require.True(t, impl.IsInUnionCast())
impl.SetInUnionCast(false)
require.False(t, impl.IsInUnionCast())
}
4 changes: 2 additions & 2 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Exp
if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
ctx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check"))
}
if ctx.GetSessionVars().StmtCtx.InNullRejectCheck {
if ctx.IsInNullRejectCheck() {
expr, _ = evaluateExprWithNullInNullRejectCheck(ctx, schema, expr)
return expr
}
Expand Down Expand Up @@ -1022,7 +1022,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr,
newCol := &Column{
RetType: col.FieldType.Clone(),
ID: col.ID,
UniqueID: ctx.GetSessionVars().AllocPlanColumnID(),
UniqueID: ctx.AllocPlanColumnID(),
Index: col.Offset,
OrigName: names[i].String(),
IsHidden: col.Hidden,
Expand Down
8 changes: 4 additions & 4 deletions pkg/planner/core/rule_predicate_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,10 @@ func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression.
return false
}
sc := ctx.GetSessionVars().StmtCtx
sc.InNullRejectCheck = true
defer func() {
sc.InNullRejectCheck = false
}()
if !exprCtx.IsInNullRejectCheck() {
exprCtx.SetInNullRejectCheck(true)
defer exprCtx.SetInNullRejectCheck(false)
}
for _, cond := range expression.SplitCNFItems(expr) {
if isNullRejectedSpecially(ctx, schema, expr) {
return true
Expand Down
1 change: 0 additions & 1 deletion pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ type StatementContext struct {
ForcePlanCache bool // force the optimizer to use plan cache even if there is risky optimization, see #49736.
CacheType PlanCacheType
BatchCheck bool
InNullRejectCheck bool
IgnoreExplainIDSuffix bool
MultiSchemaInfo *model.MultiSchemaInfo
// If the select statement was like 'select * from t as of timestamp ...' or in a stale read transaction
Expand Down

0 comments on commit fab13af

Please sign in to comment.