Skip to content

Commit

Permalink
expression: add some methods in BuildContext to read fields in Sess…
Browse files Browse the repository at this point in the history
…ionVars before
  • Loading branch information
lcwangchao committed Apr 7, 2024
1 parent 482ce59 commit 97c8879
Show file tree
Hide file tree
Showing 50 changed files with 315 additions and 186 deletions.
29 changes: 19 additions & 10 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/pkg/ddl/resourcegroup"
ddlutil "github.com/pingcap/tidb/pkg/ddl/util"
rg "github.com/pingcap/tidb/pkg/domain/resourcegroup"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
Expand Down Expand Up @@ -2267,7 +2268,8 @@ func BuildTableInfo(
if len(hiddenCols) > 0 {
AddIndexColumnFlag(tbInfo, idxInfo)
}
_, err = validateCommentLength(ctx.GetSessionVars(), idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment)
sessionVars := ctx.GetSessionVars()
_, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -2621,7 +2623,8 @@ func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh
return nil, errors.Trace(err)
}

if _, err = validateCommentLength(ctx.GetSessionVars(), tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil {
sessionVars := ctx.GetSessionVars()
if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil {
return nil, errors.Trace(err)
}

Expand Down Expand Up @@ -5562,7 +5565,9 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col
if col.Comment, err = value.ToString(); err != nil {
return errors.Trace(err)
}
col.Comment, err = validateCommentLength(ctx.GetSessionVars(), col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment)

sessionVars := ctx.GetSessionVars()
col.Comment, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment)
return errors.Trace(err)
}

Expand Down Expand Up @@ -6370,7 +6375,8 @@ func (d *ddl) AlterTableComment(ctx sessionctx.Context, ident ast.Ident, spec *a
if err != nil {
return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name))
}
if _, err = validateCommentLength(ctx.GetSessionVars(), ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil {
sessionVars := ctx.GetSessionVars()
if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -7413,7 +7419,8 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m

// May be truncate comment here, when index comment too long and sql_mode is't strict.
if indexOption != nil {
if _, err = validateCommentLength(ctx.GetSessionVars(), indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil {
sessionVars := ctx.GetSessionVars()
if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil {
return errors.Trace(err)
}
}
Expand Down Expand Up @@ -7665,7 +7672,8 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde
}
// May be truncate comment here, when index comment too long and sql_mode is't strict.
if indexOption != nil {
if _, err = validateCommentLength(ctx.GetSessionVars(), indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil {
sessionVars := ctx.GetSessionVars()
if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil {
return errors.Trace(err)
}
}
Expand Down Expand Up @@ -8069,7 +8077,7 @@ func isDroppableColumn(tblInfo *model.TableInfo, colName model.CIStr) error {
// validateCommentLength checks comment length of table, column, or index
// If comment length is more than the standard length truncate it
// and store the comment length upto the standard comment length size.
func validateCommentLength(vars *variable.SessionVars, name string, comment *string, errTooLongComment *terror.Error) (string, error) {
func validateCommentLength(ec errctx.Context, sqlMode mysql.SQLMode, name string, comment *string, errTooLongComment *terror.Error) (string, error) {
if comment == nil {
return "", nil
}
Expand All @@ -8086,11 +8094,11 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str
}
if len(*comment) > maxLen {
err := errTooLongComment.GenWithStackByArgs(name, maxLen)
if vars.SQLMode.HasStrictMode() {
if sqlMode.HasStrictMode() {
// may be treated like an error.
return "", err
}
vars.StmtCtx.AppendWarning(err)
ec.AppendWarning(err)
*comment = (*comment)[:maxLen]
}
return *comment, nil
Expand Down Expand Up @@ -8231,7 +8239,8 @@ func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes
return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs()
}
}
newVal, err := val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &colType)
evalCtx := ctx.GetEvalCtx()
newVal, err := val.ConvertTo(evalCtx.TypeCtx(), &colType)
if err != nil {
return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs()
}
Expand Down
18 changes: 11 additions & 7 deletions pkg/ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,13 @@ func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.Partitio
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: number of partitions generated != partition defined (%d != %d)", len(a), len(m))
}

evalCtx := ctx.GetEvalCtx()
evalFn := func(expr ast.ExprNode) (types.Datum, error) {
val, err := expression.EvalSimpleAst(ctx, ast.NewValueExpr(expr, "", ""))
if err != nil || partCol == nil {
return val, err
}
return val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType)
return val.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType)
}
for i := range pAst.Definitions {
// Allow options to differ! (like Placement Rules)
Expand Down Expand Up @@ -840,7 +841,7 @@ func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.Partitio
if err != nil {
return err
}
cmp, err := lessThanVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator())
cmp, err := lessThanVal.Compare(evalCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator())
if err != nil {
return err
}
Expand Down Expand Up @@ -1093,8 +1094,9 @@ func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTable
if err != nil {
return err
}
evalCtx := ctx.GetEvalCtx()
if partCol != nil {
lastVal, err = lastVal.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType)
lastVal, err = lastVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType)
if err != nil {
return err
}
Expand Down Expand Up @@ -1143,12 +1145,12 @@ func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTable
return err
}
if partCol != nil {
currVal, err = currVal.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &partCol.FieldType)
currVal, err = currVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType)
if err != nil {
return err
}
}
cmp, err := currVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator())
cmp, err := currVal.Compare(evalCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator())
if err != nil {
return err
}
Expand Down Expand Up @@ -1416,7 +1418,8 @@ func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.Par
}
}
comment, _ := def.Comment()
comment, err := validateCommentLength(ctx.GetSessionVars(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment)
evalCtx := ctx.GetEvalCtx()
comment, err := validateCommentLength(evalCtx.ErrCtx(), evalCtx.SQLMode(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1499,7 +1502,8 @@ func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs [
return dbterror.ErrValuesIsNotIntType.GenWithStackByArgs(defName)
}

_, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), tp)
evalCtx := ctx.GetEvalCtx()
_, err = val.ConvertTo(evalCtx.TypeCtx(), tp)
if err != nil && !types.ErrOverflow.Equal(err) {
return dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs()
}
Expand Down
1 change: 0 additions & 1 deletion pkg/executor/aggfuncs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ go_library(
"//pkg/parser/charset",
"//pkg/parser/mysql",
"//pkg/planner/util",
"//pkg/sessionctx/variable",
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/codec",
Expand Down
23 changes: 6 additions & 17 deletions pkg/executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@
package aggfuncs

import (
"context"
"fmt"
"strconv"

"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/aggregation"
exprctx "github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
Expand All @@ -33,7 +30,7 @@ import (
)

// AggFuncBuildContext is used to build aggregation functions.
type AggFuncBuildContext = exprctx.BuildContext
type AggFuncBuildContext = exprctx.AggFuncBuildContext

// Build is used to build a specific AggFunc implementation according to the
// input aggFuncDesc.
Expand Down Expand Up @@ -287,7 +284,7 @@ func buildSum(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, ord
if aggFuncDesc.HasDistinct {
return &sum4DistinctFloat64{base}
}
if ctx.GetSessionVars().WindowingUseHighPrecision {
if ctx.GetWindowingUseHighPrecision() {
return &sum4Float64HighPrecision{baseSum4Float64{base}}
}
return &sum4Float64{baseSum4Float64{base}}
Expand Down Expand Up @@ -321,7 +318,7 @@ func buildAvg(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, ord
if aggFuncDesc.HasDistinct {
return &avgOriginal4DistinctFloat64{base}
}
if ctx.GetSessionVars().WindowingUseHighPrecision {
if ctx.GetWindowingUseHighPrecision() {
return &avgOriginal4Float64HighPrecision{baseAvgFloat64{base}}
}
return &avgOriginal4Float64{avgOriginal4Float64HighPrecision{baseAvgFloat64{base}}}
Expand Down Expand Up @@ -477,16 +474,7 @@ func buildGroupConcat(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncD
if err != nil {
panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error()))
}
var s string
s, err = ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.GroupConcatMaxLen)
if err != nil {
panic(fmt.Sprintf("Error happened when buildGroupConcat: no system variable named '%s'", variable.GroupConcatMaxLen))
}
maxLen, err := strconv.ParseUint(s, 10, 64)
// Should never happen
if err != nil {
panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error()))
}
maxLen := ctx.GetGroupConcatMaxLen()
var truncated int32
base := baseGroupConcat4String{
baseAggFunc: baseAggFunc{
Expand Down Expand Up @@ -719,7 +707,8 @@ func buildLeadLag(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc,
if len(aggFuncDesc.Args) == 3 {
defaultExpr = aggFuncDesc.Args[2]
if et, ok := defaultExpr.(*expression.Constant); ok {
res, err1 := et.Value.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), aggFuncDesc.RetTp)
evalCtx := ctx.GetEvalCtx()
res, err1 := et.Value.ConvertTo(evalCtx.TypeCtx(), aggFuncDesc.RetTp)
if err1 == nil {
defaultExpr = &expression.Constant{Value: res, RetType: aggFuncDesc.RetTp}
}
Expand Down
1 change: 0 additions & 1 deletion pkg/expression/aggregation/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ go_library(
"//pkg/parser/terror",
"//pkg/planner/util",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/codec",
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (a *baseFuncDesc) TypeInfer(ctx expression.BuildContext) error {
case ast.AggFuncSum:
a.typeInfer4Sum()
case ast.AggFuncAvg:
a.typeInfer4Avg(ctx.GetSessionVars().GetDivPrecisionIncrement())
a.typeInfer4Avg(ctx.GetEvalCtx().GetDivPrecisionIncrement())
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
Expand Down
19 changes: 2 additions & 17 deletions pkg/expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@ package aggregation

import (
"bytes"
"context"
"fmt"
"math"
"strconv"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/size"
Expand Down Expand Up @@ -218,7 +214,7 @@ func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, sche
}

// GetAggFunc gets an evaluator according to the aggregation function signature.
func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation {
func (a *AggFuncDesc) GetAggFunc(ctx expression.AggFuncBuildContext) Aggregation {
aggFunc := aggFunction{AggFuncDesc: a}
switch a.Name {
case ast.AggFuncSum:
Expand All @@ -228,18 +224,7 @@ func (a *AggFuncDesc) GetAggFunc(ctx expression.BuildContext) Aggregation {
case ast.AggFuncAvg:
return &avgFunction{aggFunction: aggFunc}
case ast.AggFuncGroupConcat:
var s string
var err error
var maxLen uint64
s, err = ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.GroupConcatMaxLen)
if err != nil {
panic(fmt.Sprintf("Error happened when GetAggFunc: no system variable named '%s'", variable.GroupConcatMaxLen))
}
maxLen, err = strconv.ParseUint(s, 10, 64)
if err != nil {
panic(fmt.Sprintf("Error happened when GetAggFunc: illegal value for system variable named '%s'", variable.GroupConcatMaxLen))
}
return &concatFunction{aggFunction: aggFunc, maxLen: maxLen}
return &concatFunction{aggFunction: aggFunc, maxLen: ctx.GetGroupConcatMaxLen()}
case ast.AggFuncMax:
return &maxMinFunction{aggFunction: aggFunc, isMax: true, ctor: collate.GetCollator(a.Args[0].GetType().GetCollate())}
case ast.AggFuncMin:
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
)

type benchHelper struct {
ctx BuildContext
ctx *mock.Context
exprs []Expression

inputTypes []*types.FieldType
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr
if err != nil {
return nil, err
}
if (mysql.HasUnsignedFlag(args[0].GetType().GetFlag()) || mysql.HasUnsignedFlag(args[1].GetType().GetFlag())) && !ctx.GetSessionVars().SQLMode.HasNoUnsignedSubtractionMode() {
if (mysql.HasUnsignedFlag(args[0].GetType().GetFlag()) || mysql.HasUnsignedFlag(args[1].GetType().GetFlag())) && !ctx.GetEvalCtx().SQLMode().HasNoUnsignedSubtractionMode() {
bf.tp.AddFlag(mysql.UnsignedFlag)
}
sig := &builtinArithmeticMinusIntSig{baseBuiltinFunc: bf}
Expand Down Expand Up @@ -661,7 +661,7 @@ func (c *arithmeticDivideFunctionClass) getFunction(ctx BuildContext, args []Exp
if err != nil {
return nil, err
}
c.setType4DivDecimal(bf.tp, lhsTp, rhsTp, ctx.GetSessionVars().GetDivPrecisionIncrement())
c.setType4DivDecimal(bf.tp, lhsTp, rhsTp, ctx.GetEvalCtx().GetDivPrecisionIncrement())
sig := &builtinArithmeticDivideDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_DivideDecimal)
return sig, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2270,7 +2270,7 @@ func WrapWithCastAsString(ctx BuildContext, expr Expression) Expression {
tp.SetCharset(charset.CharsetBin)
tp.SetCollate(charset.CollationBin)
} else {
charset, collate := ctx.GetSessionVars().GetCharsetInfo()
charset, collate := ctx.GetCharsetInfo()
tp.SetCharset(charset)
tp.SetCollate(collate)
}
Expand Down
Loading

0 comments on commit 97c8879

Please sign in to comment.