Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stmtctx, *: change TypeCtx field to a private field #47742

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
return str, false, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx)
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx())
if err != nil {
return nil, false, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
if err != nil && gCol.FieldType.IsArray() {
return nil, completeError(tbl, gCol.Offset, rowIdx, err)
}
if e.Ctx().GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) != nil {
if e.Ctx().GetSessionVars().StmtCtx.HandleTruncate(err) != nil {
return nil, err
}
row[colIdx], err = table.CastValue(e.Ctx(), val, gCol.ToInfo(), false, false)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv
func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) (err error) {
switch af.Mode {
case Partial1Mode, CompleteMode:
err = af.updateSum(sc.TypeCtx, evalCtx, row)
err = af.updateSum(sc.TypeCtx(), evalCtx, row)
case Partial2Mode, FinalMode:
err = af.updateAvg(sc.TypeCtx, evalCtx, row)
err = af.updateAvg(sc.TypeCtx(), evalCtx, row)
case DedupMode:
panic("DedupMode is not supported now.")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type sumFunction struct {

// Update implements Aggregation interface.
func (sf *sumFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) error {
return sf.updateSum(sc.TypeCtx, evalCtx, row)
return sf.updateSum(sc.TypeCtx(), evalCtx, row)
}

// GetResult implements Aggregation interface.
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M
return c, true, handleDivisionByZeroError(s.ctx)
} else if err == types.ErrTruncated {
sc := s.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == nil {
_, frac := c.PrecisionAndFrac()
if frac < s.baseBuiltinFunc.tp.GetDecimal() {
Expand Down Expand Up @@ -846,7 +846,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64
return 0, true, handleDivisionByZeroError(s.ctx)
}
if err == types.ErrTruncated {
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
}
if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r
result.SetNull(i, true)
continue
} else if err == types.ErrTruncated {
if err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
if err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
return err
}
} else if err == nil {
Expand Down Expand Up @@ -617,7 +617,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(input *chunk.Chunk, re
continue
}
if err == types.ErrTruncated {
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
err = sc.HandleOverflow(newErr, newErr)
Expand Down
44 changes: 22 additions & 22 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ
if item.TypeCode != types.JSONTypeCodeString {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx, false)
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx(), false)
}
case types.ETInt:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
Expand All @@ -552,7 +552,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ
if item.TypeCode != types.JSONTypeCodeFloat64 && item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ConvertJSONToFloat(sc.TypeCtx, item)
return types.ConvertJSONToFloat(sc.TypeCtx(), item)
}
case types.ETDatetime:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
Expand Down Expand Up @@ -730,7 +730,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul
if tp.GetType() == mysql.TypeYear && res == "0" {
res = "0000"
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -790,7 +790,7 @@ func (b *builtinCastIntAsDurationSig) evalDuration(row chunk.Row) (res types.Dur
err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err)
}
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
}
return res, true, err
}
Expand Down Expand Up @@ -1045,7 +1045,7 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
bits = 32
}
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1102,7 +1102,7 @@ func (b *builtinCastRealAsDurationSig) evalDuration(row chunk.Row) (res types.Du
res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), b.tp.GetDecimal())
if err != nil {
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
// ErrTruncatedWrongVal needs to be considered NULL.
return res, true, err
}
Expand Down Expand Up @@ -1191,7 +1191,7 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1279,7 +1279,7 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types
}
res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
// ErrTruncatedWrongVal needs to be considered NULL.
return res, true, err
}
Expand All @@ -1301,7 +1301,7 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is
if isNull || err != nil {
return res, isNull, err
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1366,7 +1366,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo
var ures uint64
sc := b.ctx.GetSessionVars().StmtCtx
if !isNegative {
ures, err = types.StrToUint(sc.TypeCtx, val, true)
ures, err = types.StrToUint(sc.TypeCtx(), val, true)
res = int64(ures)

if err == nil && !mysql.HasUnsignedFlag(b.tp.GetFlag()) && ures > uint64(math.MaxInt64) {
Expand All @@ -1375,7 +1375,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo
} else if b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) {
res = 0
} else {
res, err = types.StrToInt(sc.TypeCtx, val, true)
res, err = types.StrToInt(sc.TypeCtx(), val, true)
if err == nil && mysql.HasUnsignedFlag(b.tp.GetFlag()) {
// If overflow, don't append this warnings
sc.AppendWarning(types.ErrCastNegIntAsUnsigned)
Expand Down Expand Up @@ -1411,7 +1411,7 @@ func (b *builtinCastStringAsRealSig) evalReal(row chunk.Row) (res float64, isNul
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.StrToFloat(sc.TypeCtx, val, true)
res, err = types.StrToFloat(sc.TypeCtx(), val, true)
if err != nil {
return 0, false, err
}
Expand Down Expand Up @@ -1449,7 +1449,7 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M
if err == types.ErrTruncated {
err = types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", []byte(val))
}
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1506,7 +1506,7 @@ func (b *builtinCastStringAsDurationSig) evalDuration(row chunk.Row) (res types.
res, isNull, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
}
return res, isNull, err
}
Expand Down Expand Up @@ -1619,7 +1619,7 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1752,7 +1752,7 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string,
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1854,7 +1854,7 @@ func (b *builtinCastJSONAsRealSig) evalReal(row chunk.Row) (res float64, isNull
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ConvertJSONToFloat(sc.TypeCtx, val)
res, err = types.ConvertJSONToFloat(sc.TypeCtx(), val)
return
}

Expand All @@ -1874,7 +1874,7 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD
return res, isNull, err
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ConvertJSONToDecimal(sc.TypeCtx, val)
res, err = types.ConvertJSONToDecimal(sc.TypeCtx(), val)
if err != nil {
return res, false, err
}
Expand All @@ -1897,7 +1897,7 @@ func (b *builtinCastJSONAsStringSig) evalString(row chunk.Row) (res string, isNu
if isNull || err != nil {
return res, isNull, err
}
s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false)
s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false)
if err != nil {
return res, false, err
}
Expand Down Expand Up @@ -1960,7 +1960,7 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu
return res, isNull, err
default:
err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String())
return res, true, b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err)
return res, true, b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err)
}
}

Expand Down Expand Up @@ -2002,12 +2002,12 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row chunk.Row) (res types.Du
res, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal())
if types.ErrTruncatedWrongVal.Equal(err) {
sc := b.ctx.GetSessionVars().StmtCtx
err = sc.TypeCtx.HandleTruncate(err)
err = sc.HandleTruncate(err)
}
return res, isNull, err
default:
err = types.ErrTruncatedWrongVal.GenWithStackByArgs("TIME", val.String())
return res, true, stmtCtx.TypeCtx.HandleTruncate(err)
return res, true, stmtCtx.HandleTruncate(err)
}
}

Expand Down
Loading