diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index e608d910ba821..32fb746a25472 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -187,9 +187,8 @@ type baseAggFunc struct { // used to append the final result of this function. ordinal int - // frac stores digits of the fractional part of decimals, - // which makes the decimal be the result of type inferring. - frac int + // retTp means the target type of the final agg should return. + retTp *types.FieldType } func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) { diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index d562a9241822a..b105e3cb46f0c 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -17,7 +17,6 @@ import ( "fmt" "strconv" - "github.com/cznic/mathutil" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -194,6 +193,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, + retTp: aggFuncDesc.RetTp, } // If HasDistinct and mode is CompleteMode or Partial1Mode, we should @@ -253,13 +253,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi baseAggFunc: baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, + retTp: aggFuncDesc.RetTp, }, } - frac := base.args[0].GetType().Decimal - if frac == -1 { - frac = mysql.MaxDecimalScale - } - base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) switch aggFuncDesc.Mode { case aggregation.DedupMode: return nil @@ -287,16 +283,8 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi base := baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, + retTp: aggFuncDesc.RetTp, } - frac := base.args[0].GetType().Decimal - if len(base.args) == 2 { - frac = base.args[1].GetType().Decimal - } - if frac == -1 { - frac = mysql.MaxDecimalScale - } - base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) - switch aggFuncDesc.Mode { // Build avg functions which consume the original data and remove the // duplicated input of the same group. @@ -340,13 +328,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, + retTp: aggFuncDesc.RetTp, } - frac := base.args[0].GetType().Decimal - if frac == -1 { - frac = mysql.MaxDecimalScale - } - base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) - evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp if fieldType.Tp == mysql.TypeBit { evalType = types.ETString @@ -392,16 +375,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) baseAggFunc: baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, + retTp: aggFuncDesc.RetTp, }, isMax: isMax, collator: collate.GetCollator(aggFuncDesc.RetTp.Collate), } - frac := base.args[0].GetType().Decimal - if frac == -1 { - frac = mysql.MaxDecimalScale - } - base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) - evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp if fieldType.Tp == mysql.TypeBit { evalType = types.ETString diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index a62565ab43fa6..02216e9ba335f 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -16,6 +16,8 @@ package aggfuncs import ( "unsafe" + "github.com/pingcap/errors" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -71,7 +73,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par if err != nil { return err } - err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven) + if e.retTp == nil { + return errors.New("e.retTp of avg should not be nil") + } + frac := e.retTp.Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, frac, types.ModeHalfEven) if err != nil { return err } @@ -259,7 +268,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co if err != nil { return err } - err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven) + if e.retTp == nil { + return errors.New("e.retTp of avg should not be nil") + } + frac := e.retTp.Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, frac, types.ModeHalfEven) if err != nil { return err } diff --git a/executor/aggfuncs/func_first_row.go b/executor/aggfuncs/func_first_row.go index 99c3dbade1439..6ee037bb0b8ac 100644 --- a/executor/aggfuncs/func_first_row.go +++ b/executor/aggfuncs/func_first_row.go @@ -16,6 +16,8 @@ package aggfuncs import ( "unsafe" + "github.com/pingcap/errors" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -475,7 +477,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P chk.AppendNull(e.ordinal) return nil } - err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if e.retTp == nil { + return errors.New("e.retTp of first_row should not be nil") + } + frac := e.retTp.Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err := p.val.Round(&p.val, frac, types.ModeHalfEven) if err != nil { return err } diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 51055c6e24b5e..4900bdd901d1d 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -17,6 +17,7 @@ import ( "unsafe" "github.com/pingcap/errors" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -813,7 +814,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par chk.AppendNull(e.ordinal) return nil } - err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if e.retTp == nil { + return errors.New("e.retTp of max or min should not be nil") + } + frac := e.retTp.Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err := p.val.Round(&p.val, frac, types.ModeHalfEven) if err != nil { return err } diff --git a/executor/aggfuncs/func_sum.go b/executor/aggfuncs/func_sum.go index 77cc6745ddf65..267142703d8e6 100644 --- a/executor/aggfuncs/func_sum.go +++ b/executor/aggfuncs/func_sum.go @@ -16,6 +16,8 @@ package aggfuncs import ( "unsafe" + "github.com/pingcap/errors" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -168,7 +170,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia chk.AppendNull(e.ordinal) return nil } - err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if e.retTp == nil { + return errors.New("e.retTp of sum should not be nil") + } + frac := e.retTp.Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err := p.val.Round(&p.val, frac, types.ModeHalfEven) if err != nil { return err } diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index df69d08f21a68..d1b3ab06507b6 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -740,11 +740,28 @@ func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) { func (s *tiflashTestSuite) TestAvgOverflow(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + // avg int + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a decimal(1,0))") + tk.MustExec("alter table t set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "t") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tk.MustExec("insert into t values(9)") + for i := 0; i < 16; i++ { + tk.MustExec("insert into t select * from t") + } + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + tk.MustExec("set @@session.tidb_enforce_mpp=ON") + tk.MustQuery("select avg(a) from t group by a").Check(testkit.Rows("9.0000")) + tk.MustExec("drop table if exists t") + + // avg decimal tk.MustExec("drop table if exists td;") tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));") tk.MustExec("alter table td set tiflash replica 1") - tb := testGetTableByName(c, tk.Se, "test", "td") - err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + tb = testGetTableByName(c, tk.Se, "test", "td") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) c.Assert(err, IsNil) tk.MustExec("insert into td values (null, 22876);") tk.MustExec("insert into td values (9220557287087669248, 32767);") diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 78f260552253c..3cd9e60665484 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -42,7 +42,7 @@ type baseFuncDesc struct { func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) { b := baseFuncDesc{Name: strings.ToLower(name), Args: args} - err := b.typeInfer(ctx) + err := b.TypeInfer(ctx) return b, err } @@ -83,8 +83,8 @@ func (a *baseFuncDesc) String() string { return buffer.String() } -// typeInfer infers the arguments and return types of an function. -func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { +// TypeInfer infers the arguments and return types of an function. +func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error { switch a.Name { case ast.AggFuncCount: a.typeInfer4Count(ctx) @@ -206,6 +206,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { types.SetBinChsClnFlag(a.RetTp) } +// TypeInfer4AvgSum infers the type of sum from avg, which should extend the precision of decimal +// compatible with mysql. +func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) { + if avgRetType.Tp == mysql.TypeNewDecimal { + a.RetTp.Flen = mathutil.Min(mysql.MaxDecimalWidth, a.RetTp.Flen+22) + } +} + // typeInfer4Avg should returns a "decimal", otherwise it returns a "double". // Because child returns integer or decimal type. func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { @@ -245,6 +253,12 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) { a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0 // TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i]) + for i := 0; i < len(a.Args)-1; i++ { + if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal { + a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp) + } + } + } func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { @@ -368,18 +382,6 @@ var noNeedCastAggFuncs = map[string]struct{}{ ast.AggFuncJsonObjectAgg: {}, } -// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions -// with a cast as decimal function. See issue #19426 -func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) { - if a.Name == ast.AggFuncGroupConcat { - for i := 0; i < len(a.Args)-1; i++ { - if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal { - a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp) - } - } - } -} - // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { if len(a.Args) == 0 { diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go index 911c531ceb4f0..2e9944d94bd8b 100644 --- a/planner/core/rule_inject_extra_projection.go +++ b/planner/core/rule_inject_extra_projection.go @@ -105,7 +105,6 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { // since the types of the args are already the expected. func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) { for i := range aggFuncs { - aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx) if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode { aggFuncs[i].WrapCastForAggArgs(sctx) } diff --git a/planner/core/task.go b/planner/core/task.go index b507baaa2d323..2853b6a79ff7b 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1628,12 +1628,18 @@ func BuildFinalModeAggregation( if aggFunc.Name == ast.AggFuncAvg { cntAgg := aggFunc.Clone() cntAgg.Name = ast.AggFuncCount - cntAgg.RetTp = partial.Schema.Columns[partialCursor-2].GetType() - cntAgg.RetTp.Flag = aggFunc.RetTp.Flag + err := cntAgg.TypeInfer(sctx) + if err != nil { // must not happen + partial = nil + final = original + return + } + partial.Schema.Columns[partialCursor-2].RetType = cntAgg.RetTp // we must call deep clone in this case, to avoid sharing the arguments. sumAgg := aggFunc.Clone() sumAgg.Name = ast.AggFuncSum - sumAgg.RetTp = partial.Schema.Columns[partialCursor-1].GetType() + sumAgg.TypeInfer4AvgSum(sumAgg.RetTp) + partial.Schema.Columns[partialCursor-1].RetType = sumAgg.RetTp partial.AggFuncs = append(partial.AggFuncs, cntAgg, sumAgg) } else if aggFunc.Name == ast.AggFuncApproxCountDistinct { newAggFunc := aggFunc.Clone() @@ -1673,8 +1679,6 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { newSchema.Keys = p.schema.Keys newSchema.UniqueKeys = p.schema.UniqueKeys newAggFuncs := make([]*aggregation.AggFuncDesc, 0, 2*len(p.AggFuncs)) - ft := types.NewFieldType(mysql.TypeLonglong) - ft.Flen, ft.Decimal, ft.Charset, ft.Collate = 20, 0, charset.CharsetBin, charset.CollationBin exprs := make([]expression.Expression, 0, 2*len(p.schema.Columns)) // add agg functions schema for i, aggFunc := range p.AggFuncs { @@ -1682,24 +1686,31 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { // inset a count(column) avgCount := aggFunc.Clone() avgCount.Name = ast.AggFuncCount + err := avgCount.TypeInfer(p.ctx) + if err != nil { // must not happen + return nil + } newAggFuncs = append(newAggFuncs, avgCount) - avgCount.RetTp = ft avgCountCol := &expression.Column{ UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: ft, + RetType: avgCount.RetTp, } newSchema.Append(avgCountCol) // insert a sum(column) avgSum := aggFunc.Clone() avgSum.Name = ast.AggFuncSum + avgSum.TypeInfer4AvgSum(avgSum.RetTp) newAggFuncs = append(newAggFuncs, avgSum) - newSchema.Append(p.schema.Columns[i]) - avgSumCol := p.schema.Columns[i] + avgSumCol := &expression.Column{ + UniqueID: p.schema.Columns[i].UniqueID, + RetType: avgSum.RetTp, + } + newSchema.Append(avgSumCol) // avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end) eq := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero()) caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol) divide := expression.NewFunctionInternal(p.ctx, ast.Div, avgSumCol.RetType, avgSumCol, caseWhen) - divide.(*expression.ScalarFunction).RetType = avgSumCol.RetType + divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType exprs = append(exprs, divide) } else { newAggFuncs = append(newAggFuncs, aggFunc) diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index b3b14f50d9064..3a3c35e3dc1b2 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -1222,7 +1222,7 @@ "StreamAgg 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader 1.00 root data:StreamAgg", " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7, funcs:sum(Column#10)->Column#8", - " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(15,4) BINARY)->Column#10", + " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(37,4) BINARY)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null @@ -1233,7 +1233,7 @@ "StreamAgg 1.00 root funcs:avg(Column#7, Column#8)->Column#4", "└─TableReader 1.00 root data:StreamAgg", " └─StreamAgg 1.00 batchCop[tiflash] funcs:count(Column#9)->Column#7, funcs:sum(Column#10)->Column#8", - " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(15,4) BINARY)->Column#10", + " └─Projection 10000.00 batchCop[tiflash] test.t.a, cast(test.t.a, decimal(37,4) BINARY)->Column#10", " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null diff --git a/tools/check/check-timeout.go b/tools/check/check-timeout.go index 76105ef6543d2..6fe07257aa76d 100644 --- a/tools/check/check-timeout.go +++ b/tools/check/check-timeout.go @@ -192,7 +192,7 @@ func main() { fmt.Println("parser line error:", err) os.Exit(-1) } - if dur > 5*time.Second { + if dur > 60*time.Second { if inAllowList(testName) { continue }