diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 9aae98f0a323d..ed37d96d4a219 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1352,7 +1352,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu return str, false, err } // For other kind of fields (e.g. INT), we supply its integer as string value. - value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) + value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) if err != nil { return nil, false, err } diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index 09ddfd5500206..f7f1c48d525ae 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -706,7 +706,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue if err != nil && gCol.FieldType.IsArray() { return nil, completeError(tbl, gCol.Offset, rowIdx, err) } - if e.Ctx().GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) != nil { + if e.Ctx().GetSessionVars().StmtCtx.HandleTruncate(err) != nil { return nil, err } row[colIdx], err = table.CastValue(e.Ctx(), val, gCol.ToInfo(), false, false) diff --git a/pkg/expression/aggregation/avg.go b/pkg/expression/aggregation/avg.go index 3fa1911f9e3a4..ff5aea5739316 100644 --- a/pkg/expression/aggregation/avg.go +++ b/pkg/expression/aggregation/avg.go @@ -60,9 +60,9 @@ func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) (err error) { switch af.Mode { case Partial1Mode, CompleteMode: - err = af.updateSum(sc.TypeCtx, evalCtx, row) + err = af.updateSum(sc.TypeCtx(), evalCtx, row) case Partial2Mode, FinalMode: - err = af.updateAvg(sc.TypeCtx, evalCtx, row) + err = af.updateAvg(sc.TypeCtx(), evalCtx, row) case DedupMode: panic("DedupMode is not supported now.") } diff --git a/pkg/expression/aggregation/sum.go b/pkg/expression/aggregation/sum.go index 5169682cc3bbf..74b5e8bdf72a0 100644 --- a/pkg/expression/aggregation/sum.go +++ b/pkg/expression/aggregation/sum.go @@ -26,7 +26,7 @@ type sumFunction struct { // Update implements Aggregation interface. func (sf *sumFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) error { - return sf.updateSum(sc.TypeCtx, evalCtx, row) + return sf.updateSum(sc.TypeCtx(), evalCtx, row) } // GetResult implements Aggregation interface. diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index 6f6a84e5ea9ef..1bff34ccd91a1 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -727,7 +727,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M return c, true, handleDivisionByZeroError(s.ctx) } else if err == types.ErrTruncated { sc := s.ctx.GetSessionVars().StmtCtx - err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } else if err == nil { _, frac := c.PrecisionAndFrac() if frac < s.baseBuiltinFunc.tp.GetDecimal() { @@ -846,7 +846,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64 return 0, true, handleDivisionByZeroError(s.ctx) } if err == types.ErrTruncated { - err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } if err == types.ErrOverflow { newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c) diff --git a/pkg/expression/builtin_arithmetic_vec.go b/pkg/expression/builtin_arithmetic_vec.go index c4951c581bc38..49c62896ecf6d 100644 --- a/pkg/expression/builtin_arithmetic_vec.go +++ b/pkg/expression/builtin_arithmetic_vec.go @@ -95,7 +95,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r result.SetNull(i, true) continue } else if err == types.ErrTruncated { - if err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil { + if err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil { return err } } else if err == nil { @@ -617,7 +617,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(input *chunk.Chunk, re continue } if err == types.ErrTruncated { - err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) + err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)) } else if err == types.ErrOverflow { newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c) err = sc.HandleOverflow(newErr, newErr) diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 01bd3438ef417..b9f54d441a58d 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -534,7 +534,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ if item.TypeCode != types.JSONTypeCodeString { return nil, ErrInvalidJSONForFuncIndex } - return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx, false) + return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc.TypeCtx(), false) } case types.ETInt: return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { @@ -552,7 +552,7 @@ func convertJSON2Tp(evalType types.EvalType) func(*stmtctx.StatementContext, typ if item.TypeCode != types.JSONTypeCodeFloat64 && item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 { return nil, ErrInvalidJSONForFuncIndex } - return types.ConvertJSONToFloat(sc.TypeCtx, item) + return types.ConvertJSONToFloat(sc.TypeCtx(), item) } case types.ETDatetime: return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { @@ -730,7 +730,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul if tp.GetType() == mysql.TypeYear && res == "0" { res = "0000" } - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return res, false, err } @@ -790,7 +790,7 @@ func (b *builtinCastIntAsDurationSig) evalDuration(row chunk.Row) (res types.Dur err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err) } if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) } return res, true, err } @@ -1045,7 +1045,7 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu // If we strconv.FormatFloat the value with 64bits, the result is incorrect! bits = 32 } - res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return res, false, err } @@ -1102,7 +1102,7 @@ func (b *builtinCastRealAsDurationSig) evalDuration(row chunk.Row) (res types.Du res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(val, 'f', -1, 64), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) // ErrTruncatedWrongVal needs to be considered NULL. return res, true, err } @@ -1191,7 +1191,7 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc.TypeCtx(), false) if err != nil { return res, false, err } @@ -1279,7 +1279,7 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types } res, _, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) // ErrTruncatedWrongVal needs to be considered NULL. return res, true, err } @@ -1301,7 +1301,7 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is if isNull || err != nil { return res, isNull, err } - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return res, false, err } @@ -1366,7 +1366,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo var ures uint64 sc := b.ctx.GetSessionVars().StmtCtx if !isNegative { - ures, err = types.StrToUint(sc.TypeCtx, val, true) + ures, err = types.StrToUint(sc.TypeCtx(), val, true) res = int64(ures) if err == nil && !mysql.HasUnsignedFlag(b.tp.GetFlag()) && ures > uint64(math.MaxInt64) { @@ -1375,7 +1375,7 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo } else if b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) { res = 0 } else { - res, err = types.StrToInt(sc.TypeCtx, val, true) + res, err = types.StrToInt(sc.TypeCtx(), val, true) if err == nil && mysql.HasUnsignedFlag(b.tp.GetFlag()) { // If overflow, don't append this warnings sc.AppendWarning(types.ErrCastNegIntAsUnsigned) @@ -1411,7 +1411,7 @@ func (b *builtinCastStringAsRealSig) evalReal(row chunk.Row) (res float64, isNul return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.StrToFloat(sc.TypeCtx, val, true) + res, err = types.StrToFloat(sc.TypeCtx(), val, true) if err != nil { return 0, false, err } @@ -1449,7 +1449,7 @@ func (b *builtinCastStringAsDecimalSig) evalDecimal(row chunk.Row) (res *types.M if err == types.ErrTruncated { err = types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", []byte(val)) } - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) if err != nil { return res, false, err } @@ -1506,7 +1506,7 @@ func (b *builtinCastStringAsDurationSig) evalDuration(row chunk.Row) (res types. res, isNull, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, val, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { sc := b.ctx.GetSessionVars().StmtCtx - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } return res, isNull, err } @@ -1619,7 +1619,7 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false) if err != nil { return res, false, err } @@ -1752,7 +1752,7 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string, return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc.TypeCtx(), false) if err != nil { return res, false, err } @@ -1854,7 +1854,7 @@ func (b *builtinCastJSONAsRealSig) evalReal(row chunk.Row) (res float64, isNull return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ConvertJSONToFloat(sc.TypeCtx, val) + res, err = types.ConvertJSONToFloat(sc.TypeCtx(), val) return } @@ -1874,7 +1874,7 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD return res, isNull, err } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ConvertJSONToDecimal(sc.TypeCtx, val) + res, err = types.ConvertJSONToDecimal(sc.TypeCtx(), val) if err != nil { return res, false, err } @@ -1897,7 +1897,7 @@ func (b *builtinCastJSONAsStringSig) evalString(row chunk.Row) (res string, isNu if isNull || err != nil { return res, isNull, err } - s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return res, false, err } @@ -1960,7 +1960,7 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNu return res, isNull, err default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String()) - return res, true, b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + return res, true, b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) } } @@ -2002,12 +2002,12 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(row chunk.Row) (res types.Du res, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { sc := b.ctx.GetSessionVars().StmtCtx - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } return res, isNull, err default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs("TIME", val.String()) - return res, true, stmtCtx.TypeCtx.HandleTruncate(err) + return res, true, stmtCtx.HandleTruncate(err) } } diff --git a/pkg/expression/builtin_cast_vec.go b/pkg/expression/builtin_cast_vec.go index 4a9a1ba079d1d..5b8f0f682bf41 100644 --- a/pkg/expression/builtin_cast_vec.go +++ b/pkg/expression/builtin_cast_vec.go @@ -51,7 +51,7 @@ func (b *builtinCastIntAsDurationSig) vecEvalDuration(input *chunk.Chunk, result err = b.ctx.GetSessionVars().StmtCtx.HandleOverflow(err, err) } if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) } if err != nil { return err @@ -223,7 +223,7 @@ func (b *builtinCastRealAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, sc.TypeCtx(), false) if err != nil { return err } @@ -263,7 +263,7 @@ func (b *builtinCastDecimalAsStringSig) vecEvalString(input *chunk.Chunk, result result.AppendNull() continue } - res, e := types.ProduceStrWithSpecifiedTp(string(v.ToString()), b.tp, sc.TypeCtx, false) + res, e := types.ProduceStrWithSpecifiedTp(string(v.ToString()), b.tp, sc.TypeCtx(), false) if e != nil { return e } @@ -456,7 +456,7 @@ func (b *builtinCastJSONAsRealSig) vecEvalReal(input *chunk.Chunk, result *chunk if result.IsNull(i) { continue } - f64s[i], err = types.ConvertJSONToFloat(sc.TypeCtx, buf.GetJSON(i)) + f64s[i], err = types.ConvertJSONToFloat(sc.TypeCtx(), buf.GetJSON(i)) if err != nil { return err } @@ -716,7 +716,7 @@ func (b *builtinCastIntAsStringSig) vecEvalString(input *chunk.Chunk, result *ch if isYearType && str == "0" { str = "0000" } - str, err = types.ProduceStrWithSpecifiedTp(str, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + str, err = types.ProduceStrWithSpecifiedTp(str, b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return err } @@ -951,7 +951,7 @@ func (b *builtinCastStringAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk return err } result.MergeNulls(buf) - typeCtx := b.ctx.GetSessionVars().StmtCtx.TypeCtx + typeCtx := b.ctx.GetSessionVars().StmtCtx.TypeCtx() i64s := result.Int64s() isUnsigned := mysql.HasUnsignedFlag(b.tp.GetFlag()) unionUnsigned := isUnsigned && b.inUnion @@ -1013,7 +1013,7 @@ func (b *builtinCastStringAsDurationSig) vecEvalDuration(input *chunk.Chunk, res dur, isNull, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, buf.GetString(i), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) } if err != nil { return err @@ -1187,7 +1187,7 @@ func (b *builtinCastJSONAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - s, err := types.ProduceStrWithSpecifiedTp(buf.GetJSON(i).String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx, false) + s, err := types.ProduceStrWithSpecifiedTp(buf.GetJSON(i).String(), b.tp, b.ctx.GetSessionVars().StmtCtx.TypeCtx(), false) if err != nil { return err } @@ -1291,7 +1291,7 @@ func (b *builtinCastRealAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul dur, _, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, strconv.FormatFloat(f64s[i], 'f', -1, 64), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) if err != nil { return err } @@ -1395,7 +1395,7 @@ func (b *builtinCastDurationAsStringSig) vecEvalString(input *chunk.Chunk, resul result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(buf.GetDuration(i, fsp).String(), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(buf.GetDuration(i, fsp).String(), b.tp, sc.TypeCtx(), false) if err != nil { return err } @@ -1600,7 +1600,7 @@ func (b *builtinCastTimeAsStringSig) vecEvalString(input *chunk.Chunk, result *c result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(v.String(), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(v.String(), b.tp, sc.TypeCtx(), false) if err != nil { return err } @@ -1639,7 +1639,7 @@ func (b *builtinCastJSONAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result if result.IsNull(i) { continue } - tempres, err := types.ConvertJSONToDecimal(sc.TypeCtx, buf.GetJSON(i)) + tempres, err := types.ConvertJSONToDecimal(sc.TypeCtx(), buf.GetJSON(i)) if err != nil { return err } @@ -1685,7 +1685,7 @@ func (b *builtinCastStringAsRealSig) vecEvalReal(input *chunk.Chunk, result *chu if result.IsNull(i) { continue } - res, err := types.StrToFloat(sc.TypeCtx, buf.GetString(i), true) + res, err := types.StrToFloat(sc.TypeCtx(), buf.GetString(i), true) if err != nil { return err } @@ -1730,7 +1730,7 @@ func (b *builtinCastStringAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, resul isNegative := len(val) > 0 && val[0] == '-' dec := new(types.MyDecimal) if !(b.inUnion && mysql.HasUnsignedFlag(b.tp.GetFlag()) && isNegative) { - if err := stmtCtx.TypeCtx.HandleTruncate(dec.FromString([]byte(val))); err != nil { + if err := stmtCtx.HandleTruncate(dec.FromString([]byte(val))); err != nil { return err } dec, err := types.ProduceDecWithSpecifiedTp(dec, b.tp, stmtCtx) @@ -1874,7 +1874,7 @@ func (b *builtinCastDecimalAsDurationSig) vecEvalDuration(input *chunk.Chunk, re dur, _, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(args[i].ToString()), b.tp.GetDecimal()) if err != nil { if types.ErrTruncatedWrongVal.Equal(err) { - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(err) if err != nil { return err } @@ -1913,7 +1913,7 @@ func (b *builtinCastStringAsStringSig) vecEvalString(input *chunk.Chunk, result result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc.TypeCtx, false) + res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc.TypeCtx(), false) if err != nil { return err } @@ -1979,7 +1979,7 @@ func (b *builtinCastJSONAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul } dur, _, err = types.ParseDuration(stmtCtx, s, b.tp.GetDecimal()) if types.ErrTruncatedWrongVal.Equal(err) { - err = stmtCtx.TypeCtx.HandleTruncate(err) + err = stmtCtx.HandleTruncate(err) } if err != nil { return err @@ -1987,7 +1987,7 @@ func (b *builtinCastJSONAsDurationSig) vecEvalDuration(input *chunk.Chunk, resul ds[i] = dur.Duration default: err = types.ErrTruncatedWrongVal.GenWithStackByArgs(types.TypeStr(b.tp.GetType()), val.String()) - err = stmtCtx.TypeCtx.HandleTruncate(err) + err = stmtCtx.HandleTruncate(err) if err != nil { return err } diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index aa8b40bc34a6f..b23611acf8f2e 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -1210,7 +1210,7 @@ func (b *builtinValuesIntSig) evalInt(_ chunk.Row) (int64, bool, error) { } if len(val) < 8 { var binary types.BinaryLiteral = val - v, err := binary.ToInt(b.ctx.GetSessionVars().StmtCtx.TypeCtx) + v, err := binary.ToInt(b.ctx.GetSessionVars().StmtCtx.TypeCtx()) if err != nil { return 0, true, errors.Trace(err) } diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index f9f9940fbec16..5af1a08e1be25 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -600,7 +600,7 @@ func calculateTimeDiff(sc *stmtctx.StatementContext, lhs, rhs types.Time) (d typ d = lhs.Sub(sc, &rhs) d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } return d, err != nil, err } @@ -615,7 +615,7 @@ func calculateDurationTimeDiff(ctx sessionctx.Context, lhs, rhs types.Duration) d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) if types.ErrTruncatedWrongVal.Equal(err) { sc := ctx.GetSessionVars().StmtCtx - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } return d, err != nil, err } @@ -2275,7 +2275,7 @@ func (b *builtinTimeSig) evalDuration(row chunk.Row) (res types.Duration, isNull sc := b.ctx.GetSessionVars().StmtCtx res, _, err = types.ParseDuration(sc, expr, fsp) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } return res, isNull, err } @@ -5570,7 +5570,7 @@ func (b *builtinSecToTimeSig) evalDuration(row chunk.Row) (types.Duration, bool, minute = 59 second = 59 demical = 0 - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) if err != nil { return types.Duration{}, err != nil, err } diff --git a/pkg/expression/builtin_time_vec.go b/pkg/expression/builtin_time_vec.go index 851f2b2b2de2e..0ff70bf96e7c6 100644 --- a/pkg/expression/builtin_time_vec.go +++ b/pkg/expression/builtin_time_vec.go @@ -1928,7 +1928,7 @@ func (b *builtinSecToTimeSig) vecEvalDuration(input *chunk.Chunk, result *chunk. minute = 59 second = 59 demical = 0 - err = b.ctx.GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) + err = b.ctx.GetSessionVars().StmtCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) if err != nil { return err } @@ -2411,7 +2411,7 @@ func (b *builtinTimeSig) vecEvalDuration(input *chunk.Chunk, result *chunk.Colum res, _, err := types.ParseDuration(sc, expr, fsp) if types.ErrTruncatedWrongVal.Equal(err) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } if err != nil { return err diff --git a/pkg/expression/column.go b/pkg/expression/column.go index 437d081e10152..e7e8af727a6c1 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -422,7 +422,7 @@ func (col *Column) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, return 0, true, nil } if val.Kind() == types.KindMysqlBit { - val, err := val.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) + val, err := val.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) return int64(val), err != nil, err } res, err := val.ToInt64(ctx.GetSessionVars().StmtCtx) diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index c9c7b8197d5de..74f2504f8beaf 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -278,13 +278,13 @@ func (c *Constant) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, if c.GetType().GetType() == mysql.TypeNull || dt.IsNull() { return 0, true, nil } else if dt.Kind() == types.KindBinaryLiteral { - val, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) + val, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) return int64(val), err != nil, err } else if c.GetType().Hybrid() || dt.Kind() == types.KindString { res, err := dt.ToInt64(ctx.GetSessionVars().StmtCtx) return res, false, err } else if dt.Kind() == types.KindMysqlBit { - uintVal, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx) + uintVal, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) return int64(uintVal), false, err } return dt.GetInt64(), false, nil @@ -303,7 +303,7 @@ func (c *Constant) EvalReal(ctx sessionctx.Context, row chunk.Row) (float64, boo return 0, true, nil } if c.GetType().Hybrid() || dt.Kind() == types.KindBinaryLiteral || dt.Kind() == types.KindString { - res, err := dt.ToFloat64(ctx.GetSessionVars().StmtCtx.TypeCtx) + res, err := dt.ToFloat64(ctx.GetSessionVars().StmtCtx.TypeCtx()) return res, false, err } return dt.GetFloat64(), false, nil @@ -337,7 +337,7 @@ func (c *Constant) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (*types.My if c.GetType().GetType() == mysql.TypeNull || dt.IsNull() { return nil, true, nil } - res, err := dt.ToDecimal(ctx.GetSessionVars().StmtCtx.TypeCtx) + res, err := dt.ToDecimal(ctx.GetSessionVars().StmtCtx.TypeCtx()) if err != nil { return nil, false, err } diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index 691c8d50a29d6..4b391e00258ad 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -208,7 +208,7 @@ func foldConstant(expr Expression) (Expression, bool) { // of Constant to nil is ok. return &Constant{Value: value, RetType: x.RetType}, false } - if isTrue, err := value.ToBool(sc.TypeCtx); err == nil && isTrue == 0 { + if isTrue, err := value.ToBool(sc.TypeCtx()); err == nil && isTrue == 0 { // This Constant is created to compose the result expression of EvaluateExprWithNull when InNullRejectCheck // is true. We just check whether the result expression is null or false and then let it die. Basically, // the constant is used once briefly and will not be retained for a long time. Hence setting DeferredExpr diff --git a/pkg/expression/errors.go b/pkg/expression/errors.go index 9cdf0a582ec42..a259164682c59 100644 --- a/pkg/expression/errors.go +++ b/pkg/expression/errors.go @@ -76,7 +76,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index 9c727ba10f616..55c10319ab636 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -206,7 +206,7 @@ func TestBinopComparison(t *testing.T) { require.NoError(t, err) v, err := evalBuiltinFunc(f, chunk.Row{}) require.NoError(t, err) - val, err := v.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx) + val, err := v.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx()) require.NoError(t, err) require.Equal(t, tt.result, val) } @@ -407,10 +407,10 @@ func TestBinopNumeric(t *testing.T) { default: // we use float64 as the result type check for all. sc := ctx.GetSessionVars().StmtCtx - f, err := v.ToFloat64(sc.TypeCtx) + f, err := v.ToFloat64(sc.TypeCtx()) require.NoError(t, err) d := types.NewDatum(tt.ret) - r, err := d.ToFloat64(sc.TypeCtx) + r, err := d.ToFloat64(sc.TypeCtx()) require.NoError(t, err) require.Equal(t, r, f) } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index f81baa6d875ae..4f0bb5843e68f 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -274,7 +274,7 @@ func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, b continue } - i, err := data.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx) + i, err := data.ToBool(ctx.GetSessionVars().StmtCtx.TypeCtx()) if err != nil { i, err = HandleOverflowOnSelection(ctx.GetSessionVars().StmtCtx, i, err) if err != nil { @@ -494,14 +494,14 @@ func toBool(sc *stmtctx.StatementContext, tp *types.FieldType, eType types.EvalT } case mysql.TypeBit: var bl types.BinaryLiteral = buf.GetBytes(i) - iVal, err := bl.ToInt(sc.TypeCtx) + iVal, err := bl.ToInt(sc.TypeCtx()) if err != nil { return err } fVal = float64(iVal) } } else { - fVal, err = types.StrToFloat(sc.TypeCtx, sVal, false) + fVal, err = types.StrToFloat(sc.TypeCtx(), sVal, false) if err != nil { return err } diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 6308608794319..d0fce4ffb2597 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -182,7 +182,7 @@ func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { return now, err } - timestamp, err := types.StrToFloat(sessionVars.StmtCtx.TypeCtx, timestampStr, false) + timestamp, err := types.StrToFloat(sessionVars.StmtCtx.TypeCtx(), timestampStr, false) if err != nil { return time.Time{}, err } diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 31c5d84117a78..a349670be4e65 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -419,7 +419,7 @@ func (sf *ScalarFunction) Eval(row chunk.Row) (d types.Datum, err error) { res, err = types.ParseEnum(tp.GetElems(), str, tp.GetCollate()) if ctx := sf.GetCtx(); ctx != nil { if sc := ctx.GetSessionVars().StmtCtx; sc != nil { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } } } else { diff --git a/pkg/planner/cardinality/selectivity.go b/pkg/planner/cardinality/selectivity.go index 4fca0886eb9c9..02712229441e3 100644 --- a/pkg/planner/cardinality/selectivity.go +++ b/pkg/planner/cardinality/selectivity.go @@ -277,7 +277,7 @@ func Selectivity( ret *= 0 mask &^= 1 << uint64(i) delete(notCoveredConstants, i) - } else if isTrue, err := c.Value.ToBool(sc.TypeCtx); err == nil { + } else if isTrue, err := c.Value.ToBool(sc.TypeCtx()); err == nil { if isTrue == 0 { // c is false ret *= 0 diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index ff6ddec9e4c42..ae77f237c13ae 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -2444,7 +2444,7 @@ func getUintFromNode(ctx sessionctx.Context, n ast.Node, mustInt64orUint64 bool) return uint64(v), false, true } case string: - ctx := ctx.GetSessionVars().StmtCtx.TypeCtx + ctx := ctx.GetSessionVars().StmtCtx.TypeCtx() uVal, err := types.StrToUint(ctx, v, false) if err != nil { return 0, false, false diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index c7ad4e5f42ebf..1765f3bbbb758 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -373,7 +373,7 @@ func (p *LogicalJoin) extractFDForOuterJoin(filtersFromApply []expression.Expres // if one of the inner condition is constant false, the inner side are all null, left make constant all of that. for _, one := range innerCondition { if c, ok := one.(*expression.Constant); ok && c.DeferredExpr == nil && c.ParamMarker == nil { - if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx); err == nil { + if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx()); err == nil { if isTrue == 0 { // c is false opt.InnerIsFalse = true diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 99367e8faaf71..661328691cde7 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -518,7 +518,7 @@ func newListPartitionPruner(ctx sessionctx.Context, tbl table.Table, partitionNa func (l *listPartitionPruner) locatePartition(cond expression.Expression) (tables.ListPartitionLocation, bool, error) { switch sf := cond.(type) { case *expression.Constant: - b, err := sf.Value.ToBool(l.ctx.GetSessionVars().StmtCtx.TypeCtx) + b, err := sf.Value.ToBool(l.ctx.GetSessionVars().StmtCtx.TypeCtx()) if err == nil && b == 0 { // A constant false expression. return nil, false, nil @@ -1297,7 +1297,7 @@ type rangePruner struct { func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expression.Expression) (start int, end int, ok bool) { if constExpr, ok := expr.(*expression.Constant); ok { - if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx); err == nil && b == 0 { + if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx()); err == nil && b == 0 { // A constant false expression. return 0, 0, true } diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index bbbf233a3efcc..d5fbaab074038 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -729,7 +729,7 @@ func DeleteTrueExprs(p LogicalPlan, conds []expression.Expression) []expression. continue } sc := p.SCtx().GetSessionVars().StmtCtx - if isTrue, err := con.Value.ToBool(sc.TypeCtx); err == nil && isTrue == 1 { + if isTrue, err := con.Value.ToBool(sc.TypeCtx()); err == nil && isTrue == 1 { continue } newConds = append(newConds, cond) diff --git a/pkg/server/internal/parse/parse.go b/pkg/server/internal/parse/parse.go index 1b132a9d0d2eb..e55ca68eb6656 100644 --- a/pkg/server/internal/parse/parse.go +++ b/pkg/server/internal/parse/parse.go @@ -246,7 +246,7 @@ func ExecArgs(sc *stmtctx.StatementContext, params []expression.Expression, boun args[i] = types.NewDecimalDatum(nil) } else { var dec types.MyDecimal - err = sc.TypeCtx.HandleTruncate(dec.FromString(v)) + err = sc.HandleTruncate(dec.FromString(v)) if err != nil { return err } diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 2f28d0a3fe3bc..8aa3a0e933996 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -155,8 +155,8 @@ type StatementContext struct { _ constructor.Constructor `ctor:"NewStmtCtx,NewStmtCtxWithTimeZone,Reset"` - // TypeCtx is used to indicate how make the type conversation. - TypeCtx typectx.Context + // typeCtx is used to indicate how to make the type conversation. + typeCtx typectx.Context // Set the following variables before execution StmtHints @@ -428,7 +428,7 @@ type StatementContext struct { // NewStmtCtx creates a new statement context func NewStmtCtx() *StatementContext { sc := &StatementContext{} - sc.TypeCtx = typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning) + sc.typeCtx = typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning) return sc } @@ -436,42 +436,52 @@ func NewStmtCtx() *StatementContext { func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { intest.Assert(tz) sc := &StatementContext{} - sc.TypeCtx = typectx.NewContext(typectx.StrictFlags, tz, sc.AppendWarning) + sc.typeCtx = typectx.NewContext(typectx.StrictFlags, tz, sc.AppendWarning) return sc } // Reset resets a statement context func (sc *StatementContext) Reset() { *sc = StatementContext{ - TypeCtx: typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning), + typeCtx: typectx.NewContext(typectx.StrictFlags, time.UTC, sc.AppendWarning), } } // TimeZone returns the timezone of the type context func (sc *StatementContext) TimeZone() *time.Location { - return sc.TypeCtx.Location() + return sc.typeCtx.Location() } // SetTimeZone sets the timezone func (sc *StatementContext) SetTimeZone(tz *time.Location) { intest.Assert(tz) - sc.TypeCtx = sc.TypeCtx.WithLocation(tz) + sc.typeCtx = sc.typeCtx.WithLocation(tz) +} + +// TypeCtx returns the type context +func (sc *StatementContext) TypeCtx() typectx.Context { + return sc.typeCtx } // TypeFlags returns the type flags func (sc *StatementContext) TypeFlags() typectx.Flags { - return sc.TypeCtx.Flags() + return sc.typeCtx.Flags() } // SetTypeFlags sets the type flags func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) { - sc.TypeCtx = sc.TypeCtx.WithFlags(flags) + sc.typeCtx = sc.typeCtx.WithFlags(flags) } // UpdateTypeFlags updates the flags of the type context func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) { - flags := fn(sc.TypeCtx.Flags()) - sc.TypeCtx = sc.TypeCtx.WithFlags(flags) + flags := fn(sc.typeCtx.Flags()) + sc.typeCtx = sc.typeCtx.WithFlags(flags) +} + +// HandleTruncate ignores or returns the error based on the TypeContext inside. +func (sc *StatementContext) HandleTruncate(err error) error { + return sc.typeCtx.HandleTruncate(err) } // StmtHints are SessionVars related sql hints. @@ -1134,7 +1144,7 @@ func (sc *StatementContext) ShouldClipToZero() bool { // so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. func (sc *StatementContext) ShouldIgnoreOverflowError() bool { // TODO: move this function into `/types` pkg - if (sc.InInsertStmt && sc.TypeCtx.Flags().TruncateAsWarning()) || sc.InLoadDataStmt { + if (sc.InInsertStmt && sc.TypeFlags().TruncateAsWarning()) || sc.InLoadDataStmt { return true } return false @@ -1150,9 +1160,9 @@ func (sc *StatementContext) PushDownFlags() uint64 { } else if sc.InSelectStmt { flags |= model.FlagInSelectStmt } - if sc.TypeCtx.Flags().IgnoreTruncateErr() { + if sc.TypeFlags().IgnoreTruncateErr() { flags |= model.FlagIgnoreTruncate - } else if sc.TypeCtx.Flags().TruncateAsWarning() { + } else if sc.TypeFlags().TruncateAsWarning() { flags |= model.FlagTruncateAsWarning } if sc.OverflowAsWarning { @@ -1227,11 +1237,11 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location) sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 sc.SetTimeZone(tz) - typeFlags := sc.TypeCtx.Flags() + typeFlags := sc.TypeFlags() typeFlags = typeFlags. WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0). WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0) - sc.TypeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning) + sc.typeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning) } // GetLockWaitStartTime returns the statement pessimistic lock wait start time @@ -1376,7 +1386,7 @@ func (sc *StatementContext) RecordedStatsLoadStatusCnt() (cnt int) { // little as possible. func (sc *StatementContext) TypeCtxOrDefault() typectx.Context { if sc != nil { - return sc.TypeCtx + return sc.typeCtx } return typectx.DefaultNoWarningContext diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 1ae5c513ea7a8..09b4ed41c2263 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -316,10 +316,10 @@ func TestStmtHintsClone(t *testing.T) { func TestNewStmtCtx(t *testing.T) { sc := stmtctx.NewStmtCtx() - require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) - require.Same(t, time.UTC, sc.TypeCtx.Location()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) + require.Same(t, time.UTC, sc.TimeZone()) require.Same(t, time.UTC, sc.TimeZone()) - sc.TypeCtx.AppendWarning(errors.New("err1")) + sc.AppendWarning(errors.New("err1")) warnings := sc.GetWarnings() require.Equal(t, 1, len(warnings)) require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) @@ -327,10 +327,10 @@ func TestNewStmtCtx(t *testing.T) { tz := time.FixedZone("UTC+1", 2*60*60) sc = stmtctx.NewStmtCtxWithTimeZone(tz) - require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) - require.Same(t, tz, sc.TypeCtx.Location()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) + require.Same(t, tz, sc.TimeZone()) require.Same(t, tz, sc.TimeZone()) - sc.TypeCtx.AppendWarning(errors.New("err2")) + sc.AppendWarning(errors.New("err2")) warnings = sc.GetWarnings() require.Equal(t, 1, len(warnings)) require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) @@ -339,34 +339,34 @@ func TestNewStmtCtx(t *testing.T) { func TestSetStmtCtxTimeZone(t *testing.T) { sc := stmtctx.NewStmtCtx() - require.Same(t, time.UTC, sc.TypeCtx.Location()) + require.Same(t, time.UTC, sc.TimeZone()) tz := time.FixedZone("UTC+1", 2*60*60) sc.SetTimeZone(tz) - require.Same(t, tz, sc.TypeCtx.Location()) + require.Same(t, tz, sc.TimeZone()) } func TestSetStmtCtxTypeFlags(t *testing.T) { sc := stmtctx.NewStmtCtx() - require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) sc.SetTypeFlags(typectx.FlagClipNegativeToZero | typectx.FlagSkipASCIICheck) require.Equal(t, typectx.FlagClipNegativeToZero|typectx.FlagSkipASCIICheck, sc.TypeFlags()) - require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) + require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning) require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) - require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) + require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags { return (flags | typectx.FlagSkipUTF8Check | typectx.FlagClipNegativeToZero) &^ typectx.FlagSkipASCIICheck }) require.Equal(t, typectx.FlagSkipUTF8Check|typectx.FlagClipNegativeToZero|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) - require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags()) + require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) } func TestResetStmtCtx(t *testing.T) { sc := stmtctx.NewStmtCtx() - require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) tz := time.FixedZone("UTC+1", 2*60*60) sc.SetTimeZone(tz) @@ -381,9 +381,9 @@ func TestResetStmtCtx(t *testing.T) { sc.Reset() require.Same(t, time.UTC, sc.TimeZone()) - require.Same(t, time.UTC, sc.TypeCtx.Location()) + require.Same(t, time.UTC, sc.TimeZone()) + require.Equal(t, types.StrictFlags, sc.TypeFlags()) require.Equal(t, types.StrictFlags, sc.TypeFlags()) - require.Equal(t, types.StrictFlags, sc.TypeCtx.Flags()) require.False(t, sc.InRestrictedSQL) require.Empty(t, sc.StmtType) require.Equal(t, 0, len(sc.GetWarnings())) diff --git a/pkg/store/mockstore/mockcopr/executor.go b/pkg/store/mockstore/mockcopr/executor.go index 78d8687a17355..70e9d5159632d 100644 --- a/pkg/store/mockstore/mockcopr/executor.go +++ b/pkg/store/mockstore/mockcopr/executor.go @@ -414,7 +414,7 @@ func evalBool(exprs []expression.Expression, row []types.Datum, ctx *stmtctx.Sta return false, nil } - isBool, err := data.ToBool(ctx.TypeCtx) + isBool, err := data.ToBool(ctx.TypeCtx()) isBool, err = expression.HandleOverflowOnSelection(ctx, isBool, err) if err != nil { return false, errors.Trace(err) diff --git a/pkg/store/mockstore/unistore/cophandler/closure_exec.go b/pkg/store/mockstore/unistore/cophandler/closure_exec.go index 94eda32c085ff..577adb3ffb486 100644 --- a/pkg/store/mockstore/unistore/cophandler/closure_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/closure_exec.go @@ -799,7 +799,7 @@ func (e *closureExecutor) processSelection(needCollectDetail bool) (gotRow bool, if d.IsNull() { gotRow = false } else { - isTrue, err := d.ToBool(e.sc.TypeCtx) + isTrue, err := d.ToBool(e.sc.TypeCtx()) isTrue, err = expression.HandleOverflowOnSelection(e.sc, isTrue, err) if err != nil { return false, errors.Trace(err) diff --git a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go index 12fc125b295bd..c86b690298e7f 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go @@ -1133,7 +1133,7 @@ func (e *selExec) next() (*chunk.Chunk, error) { if d.IsNull() { passCheck = false } else { - isBool, err := d.ToBool(e.sc.TypeCtx) + isBool, err := d.ToBool(e.sc.TypeCtx()) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/table/column.go b/pkg/table/column.go index d021732152b96..15be521014ca8 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -322,7 +322,7 @@ func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, r zap.Uint64("conn", ctx.GetSessionVars().ConnectionID), zap.Error(err)) } - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) err = sc.HandleOverflow(err, err) if forceIgnoreTruncate { diff --git a/pkg/tablecodec/tablecodec.go b/pkg/tablecodec/tablecodec.go index 1c3462b764c47..728059fa59ba2 100644 --- a/pkg/tablecodec/tablecodec.go +++ b/pkg/tablecodec/tablecodec.go @@ -398,7 +398,7 @@ func flatten(sc *stmtctx.StatementContext, data types.Datum, ret *types.Datum) e return nil case types.KindBinaryLiteral, types.KindMysqlBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - val, err := data.GetBinaryLiteral().ToInt(sc.TypeCtx) + val, err := data.GetBinaryLiteral().ToInt(sc.TypeCtx()) if err != nil { return errors.Trace(err) } diff --git a/pkg/tablecodec/tablecodec_test.go b/pkg/tablecodec/tablecodec_test.go index b237d594548de..1d2997367f666 100644 --- a/pkg/tablecodec/tablecodec_test.go +++ b/pkg/tablecodec/tablecodec_test.go @@ -552,6 +552,8 @@ func BenchmarkHasTablePrefixBuiltin(b *testing.B) { // Bench result: // BenchmarkEncodeValue 5000000 368 ns/op func BenchmarkEncodeValue(b *testing.B) { + sc := stmtctx.NewStmtCtx() + row := make([]types.Datum, 7) row[0] = types.NewIntDatum(100) row[1] = types.NewBytesDatum([]byte("abc")) @@ -565,7 +567,7 @@ func BenchmarkEncodeValue(b *testing.B) { for i := 0; i < b.N; i++ { for _, d := range row { encodedCol = encodedCol[:0] - _, err := EncodeValue(nil, encodedCol, d) + _, err := EncodeValue(sc, encodedCol, d) if err != nil { b.Fatal(err) } diff --git a/pkg/types/convert.go b/pkg/types/convert.go index 816c494ec4e70..4798103cf3654 100644 --- a/pkg/types/convert.go +++ b/pkg/types/convert.go @@ -344,7 +344,8 @@ func StrToDuration(sc *stmtctx.StatementContext, str string, fsp int) (d Duratio d, _, err = ParseDuration(sc, str, fsp) if ErrTruncatedWrongVal.Equal(err) { - err = sc.TypeCtx.HandleTruncate(err) + typeCtx := sc.TypeCtx() + err = typeCtx.HandleTruncate(err) } return d, t, true, errors.Trace(err) } @@ -571,13 +572,13 @@ func ConvertJSONToInt64(sc *stmtctx.StatementContext, j BinaryJSON, unsigned boo func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, tp byte) (int64, error) { switch j.TypeCode { case JSONTypeCodeObject, JSONTypeCodeArray, JSONTypeCodeOpaque, JSONTypeCodeDate, JSONTypeCodeDatetime, JSONTypeCodeTimestamp, JSONTypeCodeDuration: - return 0, sc.TypeCtx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) + return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) case JSONTypeCodeLiteral: switch j.Value[0] { case JSONLiteralFalse: return 0, nil case JSONLiteralNil: - return 0, sc.TypeCtx.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) + return 0, sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", j.String())) default: return 1, nil } diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 13a3bfa59a584..3a2524a9b19ab 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -633,7 +633,7 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) { func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) { typeCtx := DefaultNoWarningContext if sc != nil { - typeCtx = sc.TypeCtx + typeCtx = sc.TypeCtx() } if d.k == KindMysqlJSON && ad.k != KindMysqlJSON { cmp, err := ad.Compare(sc, d, comparer) @@ -758,6 +758,8 @@ func (d *Datum) compareFloat64(ctx Context, f float64) (int, error) { } func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer collate.Collator) (int, error) { + typeCtx := sc.TypeCtxOrDefault() + switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -767,7 +769,7 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer c return comparer.Compare(d.GetString(), s), nil case KindMysqlDecimal: dec := new(MyDecimal) - err := sc.TypeCtx.HandleTruncate(dec.FromString(hack.Slice(s))) + err := typeCtx.HandleTruncate(dec.FromString(hack.Slice(s))) return d.GetMysqlDecimal().Compare(dec), errors.Trace(err) case KindMysqlTime: dt, err := ParseDatetime(sc, s) @@ -791,6 +793,8 @@ func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer c } func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal) (int, error) { + typeCtx := sc.TypeCtxOrDefault() + switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -800,7 +804,7 @@ func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal return d.GetMysqlDecimal().Compare(dec), nil case KindString, KindBytes: dDec := new(MyDecimal) - err := sc.TypeCtx.HandleTruncate(dDec.FromString(d.GetBytes())) + err := typeCtx.HandleTruncate(dDec.FromString(d.GetBytes())) return dDec.Compare(dec), errors.Trace(err) default: dVal, err := d.ConvertTo(sc, NewFieldType(mysql.TypeNewDecimal)) @@ -1030,7 +1034,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) s string err error ) - ctx := sc.TypeCtx + ctx := sc.TypeCtx() switch d.k { case KindInt64: s = strconv.FormatInt(d.GetInt64(), 10) diff --git a/pkg/util/codec/codec.go b/pkg/util/codec/codec.go index cec50ac92cc51..a4b499297fddb 100644 --- a/pkg/util/codec/codec.go +++ b/pkg/util/codec/codec.go @@ -110,7 +110,7 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab b = append(b, decimalFlag) b, err = EncodeDecimal(b, vals[i].GetMysqlDecimal(), vals[i].Length(), vals[i].Frac()) if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) } @@ -1260,7 +1260,7 @@ func HashGroupKey(sc *stmtctx.StatementContext, n int, col *chunk.Column, buf [] buf[i] = append(buf[i], decimalFlag) buf[i], err = EncodeDecimal(buf[i], &ds[i], ft.GetFlen(), ft.GetDecimal()) if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) } diff --git a/pkg/util/ranger/points.go b/pkg/util/ranger/points.go index a637dab5578ef..6830b953dbf0c 100644 --- a/pkg/util/ranger/points.go +++ b/pkg/util/ranger/points.go @@ -209,7 +209,7 @@ func (r *builder) buildFromConstant(expr *expression.Constant) []*point { return nil } - val, err := dt.ToBool(r.sc.TypeCtx) + val, err := dt.ToBool(r.sc.TypeCtx()) if err != nil { r.err = err return nil diff --git a/pkg/util/rowcodec/encoder.go b/pkg/util/rowcodec/encoder.go index 14bb36425d219..e366454092846 100644 --- a/pkg/util/rowcodec/encoder.go +++ b/pkg/util/rowcodec/encoder.go @@ -205,7 +205,7 @@ func encodeValueDatum(sc *stmtctx.StatementContext, d *types.Datum, buffer []byt buffer, err = codec.EncodeDecimal(buffer, d.GetMysqlDecimal(), d.Length(), d.Frac()) if err != nil && sc != nil { if terror.ErrorEqual(err, types.ErrTruncated) { - err = sc.TypeCtx.HandleTruncate(err) + err = sc.HandleTruncate(err) } else if terror.ErrorEqual(err, types.ErrOverflow) { err = sc.HandleOverflow(err, err) }