diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 7076d16563567..5ef266bbeefe2 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -153,6 +153,48 @@ func (s *testIntegrationSuite) TestBitColErrorMessage(c *C) { tk.MustGetErrCode("create table bit_col_t (a bit(65))", mysql.ErrTooBigDisplaywidth) } +func (s *testIntegrationSuite) TestAggPushDownLeftJoin(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists customer") + tk.MustExec("create table customer (C_CUSTKEY bigint(20) NOT NULL, C_NAME varchar(25) NOT NULL, " + + "C_ADDRESS varchar(25) NOT NULL, PRIMARY KEY (`C_CUSTKEY`) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("drop table if exists orders") + tk.MustExec("create table orders (O_ORDERKEY bigint(20) NOT NULL, O_CUSTKEY bigint(20) NOT NULL, " + + "O_TOTALPRICE decimal(15,2) NOT NULL, PRIMARY KEY (`O_ORDERKEY`) /*T![clustered_index] CLUSTERED */)") + tk.MustExec("insert into customer values (6, \"xiao zhang\", \"address1\");") + tk.MustExec("set @@tidb_opt_agg_push_down=1;") + + tk.MustQuery("select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " + + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0")) + tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " + + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows( + "Projection 10000.00 root test.customer.c_custkey, Column#7", + "└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey", + " └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]", + " ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey", + " │ └─TableReader 8000.00 root data:HashAgg", + " │ └─HashAgg 8000.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#9", + " │ └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo", + " └─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo")) + + tk.MustQuery("select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " + + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0")) + tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " + + "on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows( + "Projection 10000.00 root test.customer.c_custkey, Column#7", + "└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey", + " └─HashJoin 10000.00 root right outer join, equal:[eq(test.orders.o_custkey, test.customer.c_custkey)]", + " ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey", + " │ └─TableReader 8000.00 root data:HashAgg", + " │ └─HashAgg 8000.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#9", + " │ └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo", + " └─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo")) +} + func (s *testIntegrationSuite) TestPushLimitDownIndexLookUpReader(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go index 079b2a76bace5..e533d1e83851c 100644 --- a/planner/core/rule_aggregation_elimination.go +++ b/planner/core/rule_aggregation_elimination.go @@ -125,7 +125,9 @@ func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, func rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) (bool, expression.Expression) { switch aggFunc.Name { case ast.AggFuncCount: - if aggFunc.Mode == aggregation.FinalMode { + if aggFunc.Mode == aggregation.FinalMode && + len(aggFunc.Args) == 1 && + mysql.HasNotNullFlag(aggFunc.Args[0].GetType().Flag) { return true, wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) } return true, rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp) diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 7cf20a0f09b88..2f37bab33c643 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -188,7 +188,8 @@ func (a *aggregationPushDownSolver) checkValidJoin(join *LogicalJoin) bool { // decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently // there are no differences between partial mode and complete mode, so we can confuse them. -func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) ([]*aggregation.AggFuncDesc, *expression.Schema) { +func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, + schema *expression.Schema, nullGenerating bool) ([]*aggregation.AggFuncDesc, *expression.Schema) { // Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case. result := []*aggregation.AggFuncDesc{aggFunc.Clone()} for _, aggFunc := range result { @@ -197,7 +198,21 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a RetType: aggFunc.RetTp, }) } - aggFunc.Args = expression.Column2Exprs(schema.Columns[schema.Len()-len(result):]) + cols := schema.Columns[schema.Len()-len(result):] + aggFunc.Args = make([]expression.Expression, 0, len(cols)) + // if the partial aggregation is on the null generating side, we have to clear the NOT NULL flag + // for the final aggregate functions' arguments + for _, col := range cols { + if nullGenerating { + arg := *col + newFieldType := *arg.RetType + newFieldType.Flag &= ^mysql.NotNullFlag + arg.RetType = &newFieldType + aggFunc.Args = append(aggFunc.Args, &arg) + } else { + aggFunc.Args = append(aggFunc.Args, col) + } + } aggFunc.Mode = aggregation.FinalMode return result, schema } @@ -220,7 +235,9 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg return child, nil } } - agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, aggHints, blockOffset) + nullGenerating := (join.JoinType == LeftOuterJoin && childIdx == 1) || + (join.JoinType == RightOuterJoin && childIdx == 0) + agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, aggHints, blockOffset, nullGenerating) if err != nil { return nil, err } @@ -266,7 +283,8 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation. // TODO: // 1. https://github.com/pingcap/tidb/issues/16355, push avg & distinct functions across join // 2. remove this method and use splitPartialAgg instead for clean code. -func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, aggHints aggHintInfo, blockOffset int) (*LogicalAggregation, error) { +func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, + gbyCols []*expression.Column, aggHints aggHintInfo, blockOffset int, nullGenerating bool) (*LogicalAggregation, error) { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), aggHints: aggHints, @@ -276,7 +294,7 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs schema := expression.NewSchema(make([]*expression.Column, 0, aggLen)...) for _, aggFunc := range aggFuncs { var newFuncs []*aggregation.AggFuncDesc - newFuncs, schema = a.decompose(ctx, aggFunc, schema) + newFuncs, schema = a.decompose(ctx, aggFunc, schema, nullGenerating) newAggFuncDescs = append(newAggFuncDescs, newFuncs...) } for _, gbyCol := range gbyCols { @@ -418,6 +436,11 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e } join.SetChildren(lChild, rChild) join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) + if join.JoinType == LeftOuterJoin { + resetNotNullFlag(join.schema, lChild.Schema().Len(), join.schema.Len()) + } else if join.JoinType == RightOuterJoin { + resetNotNullFlag(join.schema, 0, lChild.Schema().Len()) + } buildKeyInfo(join) proj := a.tryToEliminateAggregation(agg) if proj != nil {