Skip to content

Commit

Permalink
executor: avoid sum from avg overflow (#30010) (#30378)
Browse files Browse the repository at this point in the history
close #29952
  • Loading branch information
ti-srebot authored Sep 19, 2022
1 parent ed72864 commit 1582c54
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 66 deletions.
5 changes: 2 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ type baseAggFunc struct {
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
// retTp means the target type of the final agg should return.
retTp *types.FieldType
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {
Expand Down
32 changes: 5 additions & 27 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"fmt"
"strconv"

"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -194,6 +193,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}

// If HasDistinct and mode is CompleteMode or Partial1Mode, we should
Expand Down Expand Up @@ -253,13 +253,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -287,16 +283,8 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -340,13 +328,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down Expand Up @@ -392,16 +375,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
isMax: isMax,
collator: collate.GetCollator(aggFuncDesc.RetTp.Collate),
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down
20 changes: 18 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -71,7 +73,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -259,7 +268,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -475,7 +477,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of first_row should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
10 changes: 9 additions & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -813,7 +814,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of max or min should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -168,7 +170,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of sum should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
21 changes: 19 additions & 2 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,11 +740,28 @@ func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) {
func (s *tiflashTestSuite) TestAvgOverflow(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
// avg int
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a decimal(1,0))")
tk.MustExec("alter table t set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "t")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into t values(9)")
for i := 0; i < 16; i++ {
tk.MustExec("insert into t select * from t")
}
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery("select avg(a) from t group by a").Check(testkit.Rows("9.0000"))
tk.MustExec("drop table if exists t")

// avg decimal
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("alter table td set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "td")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
tb = testGetTableByName(c, tk.Se, "test", "td")
err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
Expand Down
32 changes: 17 additions & 15 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type baseFuncDesc struct {

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.typeInfer(ctx)
err := b.TypeInfer(ctx)
return b, err
}

Expand Down Expand Up @@ -83,8 +83,8 @@ func (a *baseFuncDesc) String() string {
return buffer.String()
}

// typeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
// TypeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
Expand Down Expand Up @@ -206,6 +206,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

// TypeInfer4AvgSum infers the type of sum from avg, which should extend the precision of decimal
// compatible with mysql.
func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) {
if avgRetType.Tp == mysql.TypeNewDecimal {
a.RetTp.Flen = mathutil.Min(mysql.MaxDecimalWidth, a.RetTp.Flen+22)
}
}

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
Expand Down Expand Up @@ -245,6 +253,12 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {

a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0
// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}

}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
Expand Down Expand Up @@ -368,18 +382,6 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}

// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
Expand Down
1 change: 0 additions & 1 deletion planner/core/rule_inject_extra_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll {
// since the types of the args are already the expected.
func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) {
for i := range aggFuncs {
aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx)
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
aggFuncs[i].WrapCastForAggArgs(sctx)
}
Expand Down
31 changes: 21 additions & 10 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1628,12 +1628,18 @@ func BuildFinalModeAggregation(
if aggFunc.Name == ast.AggFuncAvg {
cntAgg := aggFunc.Clone()
cntAgg.Name = ast.AggFuncCount
cntAgg.RetTp = partial.Schema.Columns[partialCursor-2].GetType()
cntAgg.RetTp.Flag = aggFunc.RetTp.Flag
err := cntAgg.TypeInfer(sctx)
if err != nil { // must not happen
partial = nil
final = original
return
}
partial.Schema.Columns[partialCursor-2].RetType = cntAgg.RetTp
// we must call deep clone in this case, to avoid sharing the arguments.
sumAgg := aggFunc.Clone()
sumAgg.Name = ast.AggFuncSum
sumAgg.RetTp = partial.Schema.Columns[partialCursor-1].GetType()
sumAgg.TypeInfer4AvgSum(sumAgg.RetTp)
partial.Schema.Columns[partialCursor-1].RetType = sumAgg.RetTp
partial.AggFuncs = append(partial.AggFuncs, cntAgg, sumAgg)
} else if aggFunc.Name == ast.AggFuncApproxCountDistinct {
newAggFunc := aggFunc.Clone()
Expand Down Expand Up @@ -1673,33 +1679,38 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection {
newSchema.Keys = p.schema.Keys
newSchema.UniqueKeys = p.schema.UniqueKeys
newAggFuncs := make([]*aggregation.AggFuncDesc, 0, 2*len(p.AggFuncs))
ft := types.NewFieldType(mysql.TypeLonglong)
ft.Flen, ft.Decimal, ft.Charset, ft.Collate = 20, 0, charset.CharsetBin, charset.CollationBin
exprs := make([]expression.Expression, 0, 2*len(p.schema.Columns))
// add agg functions schema
for i, aggFunc := range p.AggFuncs {
if aggFunc.Name == ast.AggFuncAvg {
// inset a count(column)
avgCount := aggFunc.Clone()
avgCount.Name = ast.AggFuncCount
err := avgCount.TypeInfer(p.ctx)
if err != nil { // must not happen
return nil
}
newAggFuncs = append(newAggFuncs, avgCount)
avgCount.RetTp = ft
avgCountCol := &expression.Column{
UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(),
RetType: ft,
RetType: avgCount.RetTp,
}
newSchema.Append(avgCountCol)
// insert a sum(column)
avgSum := aggFunc.Clone()
avgSum.Name = ast.AggFuncSum
avgSum.TypeInfer4AvgSum(avgSum.RetTp)
newAggFuncs = append(newAggFuncs, avgSum)
newSchema.Append(p.schema.Columns[i])
avgSumCol := p.schema.Columns[i]
avgSumCol := &expression.Column{
UniqueID: p.schema.Columns[i].UniqueID,
RetType: avgSum.RetTp,
}
newSchema.Append(avgSumCol)
// avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end)
eq := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero())
caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol)
divide := expression.NewFunctionInternal(p.ctx, ast.Div, avgSumCol.RetType, avgSumCol, caseWhen)
divide.(*expression.ScalarFunction).RetType = avgSumCol.RetType
divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType
exprs = append(exprs, divide)
} else {
newAggFuncs = append(newAggFuncs, aggFunc)
Expand Down
Loading

0 comments on commit 1582c54

Please sign in to comment.