Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: avoid sum from avg overflow (#30010) #30383

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ type baseAggFunc struct {
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
// retTp means the target type of the final agg should return.
retTp *types.FieldType
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) error {
Expand Down
34 changes: 8 additions & 26 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ import (
"fmt"
"strconv"

<<<<<<< HEAD
"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
=======
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -186,6 +189,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}

// If HasDistinct and mode is CompleteMode or Partial1Mode, we should
Expand Down Expand Up @@ -245,13 +249,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -279,16 +279,8 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -329,13 +321,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down Expand Up @@ -381,15 +368,10 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
isMax: isMax,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down
25 changes: 23 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
package aggfuncs

import (
<<<<<<< HEAD
=======
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -58,7 +65,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -206,7 +220,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
16 changes: 15 additions & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
package aggfuncs

import (
<<<<<<< HEAD
=======
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -444,7 +451,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of first_row should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
16 changes: 15 additions & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
package aggfuncs

import (
<<<<<<< HEAD
=======
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -361,7 +368,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of max or min should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
16 changes: 15 additions & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
package aggfuncs

import (
<<<<<<< HEAD
=======
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -153,7 +160,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of sum should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
39 changes: 39 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,42 @@ func (s *testSuiteAgg) TestIssue23277(c *C) {
tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000"))
tk.MustExec("drop table t;")
}
<<<<<<< HEAD
=======

func TestAvgDecimal(t *testing.T) {
t.Parallel()
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
tk.MustExec("insert into td values (28030, 32767);")
tk.MustExec("insert into td values (-3309864251140603904,32767);")
tk.MustExec("insert into td values (4,0);")
tk.MustExec("insert into td values (null,0);")
tk.MustExec("insert into td values (4,-23828);")
tk.MustExec("insert into td values (54720,32767);")
tk.MustExec("insert into td values (0,29815);")
tk.MustExec("insert into td values (10017,-32661);")
tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
tk.MustExec("drop table td;")
}

// https://github.com/pingcap/tidb/issues/23314
func TestIssue23314(t *testing.T) {
t.Parallel()
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(col1 time(2) NOT NULL)")
tk.MustExec("insert into t1 values(\"16:40:20.01\")")
res := tk.MustQuery("select col1 from t1 group by col1")
res.Check(testkit.Rows("16:40:20.01"))
}
>>>>>>> 9aa756336... executor: avoid sum from avg overflow (#30010)
Loading