Skip to content

Commit

Permalink
planner: fix bug that aggregate push down may generate wrong plan for…
Browse files Browse the repository at this point in the history
… outer joins (pingcap#34468) (pingcap#34647)

close pingcap#34465
  • Loading branch information
ti-srebot authored Jun 23, 2022
1 parent 191b5f0 commit 3279c08
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 6 deletions.
42 changes: 42 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,48 @@ func (s *testIntegrationSuite) TestBitColErrorMessage(c *C) {
tk.MustGetErrCode("create table bit_col_t (a bit(65))", mysql.ErrTooBigDisplaywidth)
}

func (s *testIntegrationSuite) TestAggPushDownLeftJoin(c *C) {
tk := testkit.NewTestKit(c, s.store)

tk.MustExec("use test")
tk.MustExec("drop table if exists customer")
tk.MustExec("create table customer (C_CUSTKEY bigint(20) NOT NULL, C_NAME varchar(25) NOT NULL, " +
"C_ADDRESS varchar(25) NOT NULL, PRIMARY KEY (`C_CUSTKEY`) /*T![clustered_index] CLUSTERED */)")
tk.MustExec("drop table if exists orders")
tk.MustExec("create table orders (O_ORDERKEY bigint(20) NOT NULL, O_CUSTKEY bigint(20) NOT NULL, " +
"O_TOTALPRICE decimal(15,2) NOT NULL, PRIMARY KEY (`O_ORDERKEY`) /*T![clustered_index] CLUSTERED */)")
tk.MustExec("insert into customer values (6, \"xiao zhang\", \"address1\");")
tk.MustExec("set @@tidb_opt_agg_push_down=1;")

tk.MustQuery("select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
" └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
" │ └─HashAgg 8000.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#9",
" │ └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo",
" └─TableReader(Probe) 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo"))

tk.MustQuery("select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
" └─HashJoin 10000.00 root right outer join, equal:[eq(test.orders.o_custkey, test.customer.c_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
" │ └─HashAgg 8000.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#9",
" │ └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo",
" └─TableReader(Probe) 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo"))
}

func (s *testIntegrationSuite) TestPushLimitDownIndexLookUpReader(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
4 changes: 3 additions & 1 deletion planner/core/rule_aggregation_elimination.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool,
func rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) (bool, expression.Expression) {
switch aggFunc.Name {
case ast.AggFuncCount:
if aggFunc.Mode == aggregation.FinalMode {
if aggFunc.Mode == aggregation.FinalMode &&
len(aggFunc.Args) == 1 &&
mysql.HasNotNullFlag(aggFunc.Args[0].GetType().Flag) {
return true, wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
}
return true, rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp)
Expand Down
33 changes: 28 additions & 5 deletions planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ func (a *aggregationPushDownSolver) checkValidJoin(join *LogicalJoin) bool {

// decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently
// there are no differences between partial mode and complete mode, so we can confuse them.
func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) ([]*aggregation.AggFuncDesc, *expression.Schema) {
func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc,
schema *expression.Schema, nullGenerating bool) ([]*aggregation.AggFuncDesc, *expression.Schema) {
// Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case.
result := []*aggregation.AggFuncDesc{aggFunc.Clone()}
for _, aggFunc := range result {
Expand All @@ -197,7 +198,21 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
RetType: aggFunc.RetTp,
})
}
aggFunc.Args = expression.Column2Exprs(schema.Columns[schema.Len()-len(result):])
cols := schema.Columns[schema.Len()-len(result):]
aggFunc.Args = make([]expression.Expression, 0, len(cols))
// if the partial aggregation is on the null generating side, we have to clear the NOT NULL flag
// for the final aggregate functions' arguments
for _, col := range cols {
if nullGenerating {
arg := *col
newFieldType := *arg.RetType
newFieldType.Flag &= ^mysql.NotNullFlag
arg.RetType = &newFieldType
aggFunc.Args = append(aggFunc.Args, &arg)
} else {
aggFunc.Args = append(aggFunc.Args, col)
}
}
aggFunc.Mode = aggregation.FinalMode
return result, schema
}
Expand All @@ -220,7 +235,9 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
return child, nil
}
}
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, aggHints, blockOffset)
nullGenerating := (join.JoinType == LeftOuterJoin && childIdx == 1) ||
(join.JoinType == RightOuterJoin && childIdx == 0)
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, aggHints, blockOffset, nullGenerating)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -266,7 +283,8 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.
// TODO:
// 1. https://github.com/pingcap/tidb/issues/16355, push avg & distinct functions across join
// 2. remove this method and use splitPartialAgg instead for clean code.
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, aggHints aggHintInfo, blockOffset int) (*LogicalAggregation, error) {
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc,
gbyCols []*expression.Column, aggHints aggHintInfo, blockOffset int, nullGenerating bool) (*LogicalAggregation, error) {
agg := LogicalAggregation{
GroupByItems: expression.Column2Exprs(gbyCols),
aggHints: aggHints,
Expand All @@ -276,7 +294,7 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs
schema := expression.NewSchema(make([]*expression.Column, 0, aggLen)...)
for _, aggFunc := range aggFuncs {
var newFuncs []*aggregation.AggFuncDesc
newFuncs, schema = a.decompose(ctx, aggFunc, schema)
newFuncs, schema = a.decompose(ctx, aggFunc, schema, nullGenerating)
newAggFuncDescs = append(newAggFuncDescs, newFuncs...)
}
for _, gbyCol := range gbyCols {
Expand Down Expand Up @@ -418,6 +436,11 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
}
join.SetChildren(lChild, rChild)
join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema()))
if join.JoinType == LeftOuterJoin {
resetNotNullFlag(join.schema, lChild.Schema().Len(), join.schema.Len())
} else if join.JoinType == RightOuterJoin {
resetNotNullFlag(join.schema, 0, lChild.Schema().Len())
}
buildKeyInfo(join)
proj := a.tryToEliminateAggregation(agg)
if proj != nil {
Expand Down

0 comments on commit 3279c08

Please sign in to comment.