diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 6e8ae1b3de67f..e54a6bd6ed05f 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -393,6 +393,19 @@ Projection_5 8000.00 root test.ta.a └─TableReader_9 10000.00 root data:TableScan_8 └─TableScan_8 10000.00 cop table:ta, range:[-inf,+inf], keep order:false, stats:pseudo rollback; +drop table if exists t1, t2; +create table t1(a int, b int, c int, primary key(a, b)); +create table t2(a int, b int, c int, primary key(a)); +explain select t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; +id count task operator info +TableReader_7 10000.00 root data:TableScan_6 +└─TableScan_6 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select distinct t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; +id count task operator info +StreamAgg_19 8000.00 root group by:col_2, col_3, funcs:firstrow(col_0), firstrow(col_1) +└─IndexReader_20 8000.00 root index:StreamAgg_10 + └─StreamAgg_10 8000.00 cop group by:test.t1.a, test.t1.b, funcs:firstrow(test.t1.a), firstrow(test.t1.b) + └─IndexScan_18 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:true, stats:pseudo drop table if exists t; create table t(a int, nb int not null, nc int not null); explain select ifnull(a, 0) from t; diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index f71291193664b..5ae5a1764c906 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -82,6 +82,13 @@ insert tb values ('1'); explain select * from ta where a = 1; rollback; +# outer join elimination +drop table if exists t1, t2; +create table t1(a int, b int, c int, primary key(a, b)); +create table t2(a int, b int, c int, primary key(a)); +explain select t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; +explain select distinct t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; + # 
https://github.com/pingcap/tidb/issues/7918 drop table if exists t; create table t(a int, nb int not null, nc int not null); diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index b340cbcfb371a..f35ba0a2e3e5b 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -309,9 +309,13 @@ func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) { // Set join type. switch joinNode.Tp { case ast.LeftJoin: + // left outer join need to be checked elimination + b.optFlag = b.optFlag | flagEliminateOuterJoin joinPlan.JoinType = LeftOuterJoin resetNotNullFlag(joinPlan.schema, leftPlan.Schema().Len(), joinPlan.schema.Len()) case ast.RightJoin: + // right outer join need to be checked elimination + b.optFlag = b.optFlag | flagEliminateOuterJoin joinPlan.JoinType = RightOuterJoin resetNotNullFlag(joinPlan.schema, 0, leftPlan.Schema().Len()) default: diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 9eaefd0e3096f..e65d9a8811348 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -2026,3 +2026,67 @@ func (s *testPlanSuite) TestNameResolver(c *C) { } } } + +func (s *testPlanSuite) TestOuterJoinEliminator(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + best string + }{ + // Test left outer join + distinct + { + sql: "select distinct t1.a, t1.b from t t1 left outer join t t2 on t1.b = t2.b", + best: "DataScan(t1)->Aggr(firstrow(t1.a),firstrow(t1.b))", + }, + // Test right outer join + distinct + { + sql: "select distinct t2.a, t2.b from t t1 right outer join t t2 on t1.b = t2.b", + best: "DataScan(t2)->Aggr(firstrow(t2.a),firstrow(t2.b))", + }, + // Test duplicate agnostic agg functions on join + { + sql: "select max(t1.a), min(t1.b) from t t1 left join t t2 on t1.b = t2.b", + best: "DataScan(t1)->Aggr(max(t1.a),min(t1.b))->Projection", + }, + { + sql: "select sum(distinct t1.a) from t t1 
left join t t2 on t1.a = t2.a and t1.b = t2.b", + best: "DataScan(t1)->Aggr(sum(t1.a))->Projection", + }, + { + sql: "select count(distinct t1.a, t1.b) from t t1 left join t t2 on t1.b = t2.b", + best: "DataScan(t1)->Aggr(count(t1.a, t1.b))->Projection", + }, + // Test left outer join + { + sql: "select t1.b from t t1 left outer join t t2 on t1.a = t2.a", + best: "DataScan(t1)->Projection", + }, + // Test right outer join + { + sql: "select t2.b from t t1 right outer join t t2 on t1.a = t2.a", + best: "DataScan(t2)->Projection", + }, + // For complex join query + { + sql: "select max(t3.b) from (t t1 left join t t2 on t1.a = t2.a) right join t t3 on t1.b = t3.b", + best: "DataScan(t3)->TopN([t3.b true],0,1)->Aggr(max(t3.b))->Projection", + }, + } + + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + Preprocess(s.ctx, stmt, s.is, false) + builder := &PlanBuilder{ + ctx: mockContext(), + is: s.is, + colMapper: make(map[*ast.ColumnNameExpr]int), + } + p, err := builder.Build(stmt) + c.Assert(err, IsNil) + p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + c.Assert(err, IsNil) + c.Assert(ToString(p), Equals, tt.best, comment) + } +} diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 305aa57eddb35..3065ebb1891d7 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -39,6 +39,7 @@ const ( flagEliminateProjection flagMaxMinEliminate flagPredicatePushDown + flagEliminateOuterJoin flagPartitionProcessor flagPushDownAgg flagPushDownTopN @@ -52,6 +53,7 @@ var optRuleList = []logicalOptRule{ &projectionEliminater{}, &maxMinEliminator{}, &ppdSolver{}, + &outerJoinEliminator{}, &partitionProcessor{}, &aggregationPushDownSolver{}, &pushDownTopNOptimizer{}, diff --git a/planner/core/rule_join_elimination.go b/planner/core/rule_join_elimination.go new file mode 100644 index 0000000000000..f6dd30834a701 --- /dev/null +++ 
b/planner/core/rule_join_elimination.go @@ -0,0 +1,192 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/expression" +) + +type outerJoinEliminator struct { +} + +// tryToEliminateOuterJoin will eliminate an outer join plan based on the following rules: +// 1. outer join elimination: taking left outer join as an example, if the parent only uses the +// columns from the left table and the join key of the right table (the inner table) is a unique +// key of the right table, the left outer join can be eliminated. +// 2. outer join elimination with duplicate agnostic aggregate functions: taking left outer join as an +// example, if the parent only uses the columns from the left table with a 'distinct' label, the left +// outer join can be eliminated. 
+func (o *outerJoinEliminator) tryToEliminateOuterJoin(p *LogicalJoin, aggCols []*expression.Column, parentSchema *expression.Schema) LogicalPlan { + var innerChildIdx int + switch p.JoinType { + case LeftOuterJoin: + innerChildIdx = 1 + case RightOuterJoin: + innerChildIdx = 0 + default: + return p + } + + outerPlan := p.children[1^innerChildIdx] + innerPlan := p.children[innerChildIdx] + // outer join elimination with duplicate agnostic aggregate functions + if o.isAggColsAllFromOuterTable(outerPlan, aggCols) { + return outerPlan + } + // outer join elimination without duplicate agnostic aggregate functions + if !o.isParentColsAllFromOuterTable(outerPlan, parentSchema) { + return p + } + innerJoinKeys := o.extractInnerJoinKeys(p, innerChildIdx) + if o.isInnerJoinKeysContainUniqueKey(innerPlan, innerJoinKeys) { + return outerPlan + } + if o.isInnerJoinKeysContainIndex(innerPlan, innerJoinKeys) { + return outerPlan + } + return p +} + +// extract join keys as a schema for inner child of a outer join +func (o *outerJoinEliminator) extractInnerJoinKeys(join *LogicalJoin, innerChildIdx int) *expression.Schema { + var joinKeys []*expression.Column + for _, eqCond := range join.EqualConditions { + joinKeys = append(joinKeys, eqCond.GetArgs()[innerChildIdx].(*expression.Column)) + } + return expression.NewSchema(joinKeys...) 
+} + +func (o *outerJoinEliminator) isAggColsAllFromOuterTable(outerPlan LogicalPlan, aggCols []*expression.Column) bool { + if len(aggCols) == 0 { + return false + } + for _, col := range aggCols { + columnName := &ast.ColumnName{Schema: col.DBName, Table: col.TblName, Name: col.ColName} + if c, _ := outerPlan.Schema().FindColumn(columnName); c == nil { + return false + } + } + return true +} + +// check whether schema cols of join's parent plan are all from outer join table +func (o *outerJoinEliminator) isParentColsAllFromOuterTable(outerPlan LogicalPlan, parentSchema *expression.Schema) bool { + if parentSchema == nil { + return false + } + for _, col := range parentSchema.Columns { + columnName := &ast.ColumnName{Schema: col.DBName, Table: col.TblName, Name: col.ColName} + if c, _ := outerPlan.Schema().FindColumn(columnName); c == nil { + return false + } + } + return true +} + +// check whether one of unique keys sets is contained by inner join keys +func (o *outerJoinEliminator) isInnerJoinKeysContainUniqueKey(innerPlan LogicalPlan, joinKeys *expression.Schema) bool { + for _, keyInfo := range innerPlan.Schema().Keys { + joinKeysContainKeyInfo := true + for _, col := range keyInfo { + columnName := &ast.ColumnName{Schema: col.DBName, Table: col.TblName, Name: col.ColName} + if c, _ := joinKeys.FindColumn(columnName); c == nil { + joinKeysContainKeyInfo = false + break + } + } + if joinKeysContainKeyInfo { + return true + } + } + return false +} + +// check whether one of index sets is contained by inner join index +func (o *outerJoinEliminator) isInnerJoinKeysContainIndex(innerPlan LogicalPlan, joinKeys *expression.Schema) bool { + ds, ok := innerPlan.(*DataSource) + if !ok { + return false + } + for _, path := range ds.possibleAccessPaths { + if path.isTablePath { + continue + } + idx := path.index + if !idx.Unique { + continue + } + joinKeysContainIndex := true + for _, idxCol := range idx.Columns { + columnName := &ast.ColumnName{Schema: ds.DBName, Table: 
ds.tableInfo.Name, Name: idxCol.Name} + if c, _ := joinKeys.FindColumn(columnName); c == nil { + joinKeysContainIndex = false + break + } + } + if joinKeysContainIndex { + return true + } + } + return false +} + +// Check whether a LogicalPlan is a LogicalAggregation and its all aggregate functions is duplicate agnostic. +// Also, check all the args are expression.Column. +func (o *outerJoinEliminator) isDuplicateAgnosticAgg(p LogicalPlan) (_ bool, cols []*expression.Column) { + agg, ok := p.(*LogicalAggregation) + if !ok { + return false, nil + } + cols = agg.groupByCols + for _, aggDesc := range agg.AggFuncs { + if !aggDesc.HasDistinct && + aggDesc.Name != ast.AggFuncFirstRow && + aggDesc.Name != ast.AggFuncMax && + aggDesc.Name != ast.AggFuncMin { + return false, nil + } + for _, expr := range aggDesc.Args { + if col, ok := expr.(*expression.Column); ok { + cols = append(cols, col) + } else { + return false, nil + } + } + } + return true, cols +} + +func (o *outerJoinEliminator) doOptimize(p LogicalPlan, aggCols []*expression.Column, parentSchema *expression.Schema) LogicalPlan { + // check the duplicate agnostic aggregate functions + if ok, newCols := o.isDuplicateAgnosticAgg(p); ok { + aggCols = newCols + } + + newChildren := make([]LogicalPlan, 0, len(p.Children())) + for _, child := range p.Children() { + newChild := o.doOptimize(child, aggCols, p.Schema()) + newChildren = append(newChildren, newChild) + } + p.SetChildren(newChildren...) + join, isJoin := p.(*LogicalJoin) + if !isJoin { + return p + } + return o.tryToEliminateOuterJoin(join, aggCols, parentSchema) +} + +func (o *outerJoinEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { + return o.doOptimize(p, nil, nil), nil +}