Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into 0703
Browse files Browse the repository at this point in the history
  • Loading branch information
spongedu committed Jul 3, 2018
2 parents 13f1cfd + ef1b9df commit a79d5ea
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 91 deletions.
12 changes: 9 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ import (

// All the AggFunc implementations are listed here for navigation.
var (
// All the AggFunc implementations for "COUNT" are listed here.
// All the AggFunc implementations for "SUM" are listed here.
// All the AggFunc implementations for "AVG" are listed here.
// All the AggFunc implementations for "COUNT" are listed here.
// All the AggFunc implementations for "SUM" are listed here.
// All the AggFunc implementations for "AVG" are listed here.
_ AggFunc = (*avgOriginal4Decimal)(nil)
_ AggFunc = (*avgPartial4Decimal)(nil)

_ AggFunc = (*avgOriginal4Float64)(nil)
_ AggFunc = (*avgPartial4Float64)(nil)

// All the AggFunc implementations for "FIRSTROW" are listed here.
// All the AggFunc implementations for "MAX" are listed here.
// All the AggFunc implementations for "MIN" are listed here.
Expand Down
37 changes: 37 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package aggfuncs
import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/mysql"
)

// Build is used to build a specific AggFunc implementation according to the
Expand Down Expand Up @@ -58,6 +59,42 @@ func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {

// buildCount builds the AggFunc implementation for function "AVG".
func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
case aggregation.DedupMode:
return nil // not implemented yet.

// Build avg functions which consume the original data and update their
// partial results.
case aggregation.CompleteMode, aggregation.Partial1Mode:
switch aggFuncDesc.Args[0].GetType().Tp {
case mysql.TypeNewDecimal:
if aggFuncDesc.HasDistinct {
return nil // not implemented yet.
}
return &avgOriginal4Decimal{baseAvgDecimal{base}}
case mysql.TypeFloat, mysql.TypeDouble:
if aggFuncDesc.HasDistinct {
return nil // not implemented yet.
}
return &avgOriginal4Float64{baseAvgFloat64{base}}
}

// Build avg functions which consume the partial result of other avg
// functions and update their partial results.
case aggregation.Partial2Mode, aggregation.FinalMode:
switch aggFuncDesc.Args[1].GetType().Tp {
case mysql.TypeNewDecimal:
return &avgPartial4Decimal{baseAvgDecimal{base}}
case mysql.TypeDouble:
return &avgPartial4Float64{baseAvgFloat64{base}}
}
}
return nil
}

Expand Down
207 changes: 207 additions & 0 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
)

// All the following avg function implementations return the decimal result,
// which store the partial results in "partialResult4AvgDecimal".
//
// "baseAvgDecimal" is wrapped by:
// - "avgOriginal4Decimal"
// - "avgPartial4Decimal"
type baseAvgDecimal struct {
baseAggFunc
}

type partialResult4AvgDecimal struct {
sum types.MyDecimal
count int64
}

func (e *baseAvgDecimal) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4AvgDecimal{})
}

func (e *baseAvgDecimal) ResetPartialResult(pr PartialResult) {
p := (*partialResult4AvgDecimal)(pr)
p.sum = *types.NewDecFromInt(0)
p.count = int64(0)
}

func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4AvgDecimal)(pr)
if p.count == 0 {
chk.AppendNull(e.ordinal)
return nil
}
decimalCount := types.NewDecFromInt(p.count)
finalResult := new(types.MyDecimal)
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr)
if err != nil {
return errors.Trace(err)
}
chk.AppendMyDecimal(e.ordinal, finalResult)
return nil
}

type avgOriginal4Decimal struct {
baseAvgDecimal
}

func (e *avgOriginal4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4AvgDecimal)(pr)
newSum := new(types.MyDecimal)
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalDecimal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

err = types.DecimalAdd(&p.sum, input, newSum)
if err != nil {
return errors.Trace(err)
}
p.sum = *newSum
p.count++
}
return nil
}

type avgPartial4Decimal struct {
baseAvgDecimal
}

func (e *avgPartial4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4AvgDecimal)(pr)
newSum := new(types.MyDecimal)
for _, row := range rowsInGroup {
inputSum, isNull, err := e.args[1].EvalDecimal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

inputCount, isNull, err := e.args[0].EvalInt(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

err = types.DecimalAdd(&p.sum, inputSum, newSum)
if err != nil {
return errors.Trace(err)
}
p.sum = *newSum
p.count += inputCount
}
return nil
}

// All the following avg function implementations return the float64 result,
// which store the partial results in "partialResult4AvgFloat64".
//
// "baseAvgFloat64" is wrapped by:
// - "avgOriginal4Float64"
// - "avgPartial4Float64"
type baseAvgFloat64 struct {
baseAggFunc
}

type partialResult4AvgFloat64 struct {
sum float64
count int64
}

func (e *baseAvgFloat64) AllocPartialResult() PartialResult {
return (PartialResult)(&partialResult4AvgFloat64{})
}

func (e *baseAvgFloat64) ResetPartialResult(pr PartialResult) {
p := (*partialResult4AvgFloat64)(pr)
p.sum = 0
p.count = 0
}

func (e *baseAvgFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4AvgFloat64)(pr)
if p.count == 0 {
chk.AppendNull(e.ordinal)
} else {
chk.AppendFloat64(e.ordinal, p.sum/float64(p.count))
}
return nil
}

type avgOriginal4Float64 struct {
baseAvgFloat64
}

func (e *avgOriginal4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4AvgFloat64)(pr)
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalReal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

p.sum += input
p.count++
}
return nil
}

type avgPartial4Float64 struct {
baseAvgFloat64
}

func (e *avgPartial4Float64) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4AvgFloat64)(pr)
for _, row := range rowsInGroup {
inputSum, isNull, err := e.args[1].EvalReal(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

inputCount, isNull, err := e.args[0].EvalInt(sctx, row)
if err != nil {
return errors.Trace(err)
}
if isNull {
continue
}

p.sum += inputSum
p.count += inputCount
}
return nil
}
Loading

0 comments on commit a79d5ea

Please sign in to comment.