Skip to content

Commit

Permalink
plan: decorrelation enhancement. (#5953)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored and zz-jason committed Mar 6, 2018
1 parent c4a7242 commit 86af180
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 12 deletions.
2 changes: 1 addition & 1 deletion executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,7 @@ func (s *testSuite) TestScanControlSelection(c *C) {
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key, b int, c int, index idx_b(b))")
tk.MustExec("insert into t values (1, 1, 1), (2, 1, 1), (3, 1, 2), (4, 2, 3)")
tk.MustQuery("select (select count(1) k from t s where s.b = t1.c) from t t1").Check(testkit.Rows("3", "3", "1", "0"))
tk.MustQuery("select (select count(1) k from t s where s.b = t1.c) from t t1").Sort().Check(testkit.Rows("0", "1", "3", "3"))
}

func (s *testSuite) TestSimpleDAG(c *C) {
Expand Down
99 changes: 99 additions & 0 deletions plan/decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
package plan

import (
"math"

"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/types"
)

Expand Down Expand Up @@ -78,9 +81,52 @@ func (la *LogicalAggregation) canPullUp() bool {
return true
}

// deCorColFromEqExpr checks whether it's an equal condition of form `col = correlated col`. If so we will change the decorrelated
// column to normal column to make a new equal condition.
func (la *LogicalApply) deCorColFromEqExpr(expr expression.Expression) expression.Expression {
sf, ok := expr.(*expression.ScalarFunction)
if !ok || sf.FuncName.L != ast.EQ {
return nil
}
if col, lOk := sf.GetArgs()[0].(*expression.Column); lOk {
if corCol, rOk := sf.GetArgs()[1].(*expression.CorrelatedColumn); rOk {
ret := corCol.Decorrelate(la.Schema())
if _, ok := ret.(*expression.CorrelatedColumn); ok {
return nil
}
// We should make sure that the equal condition's left side is the join's left join key, right is the right key.
return expression.NewFunctionInternal(la.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col)
}
}
if corCol, lOk := sf.GetArgs()[0].(*expression.CorrelatedColumn); lOk {
if col, rOk := sf.GetArgs()[1].(*expression.Column); rOk {
ret := corCol.Decorrelate(la.Schema())
if _, ok := ret.(*expression.CorrelatedColumn); ok {
return nil
}
// We should make sure that the equal condition's left side is the join's left join key, right is the right key.
return expression.NewFunctionInternal(la.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col)
}
}
return nil
}

// decorrelateSolver tries to convert apply plan to join plan.
type decorrelateSolver struct{}

func (s *decorrelateSolver) aggDefaultValueMap(agg *LogicalAggregation) map[int]*expression.Constant {
defaultValueMap := make(map[int]*expression.Constant)
for i, f := range agg.AggFuncs {
switch f.Name {
case ast.AggFuncBitOr, ast.AggFuncBitXor, ast.AggFuncCount:
defaultValueMap[i] = expression.Zero.Clone().(*expression.Constant)
case ast.AggFuncBitAnd:
defaultValueMap[i] = &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: types.NewFieldType(mysql.TypeLonglong)}
}
}
return defaultValueMap
}

// optimize implements logicalOptRule interface.
func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) {
if apply, ok := p.(*LogicalApply); ok {
Expand Down Expand Up @@ -153,6 +199,59 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) {
agg.collectGroupByColumns()
return agg, nil
}
// We can pull up the equal conditions below the aggregation as the join key of the apply, if only
// the equal conditions contain the correlated column of this apply.
if sel, ok := agg.children[0].(*LogicalSelection); ok && apply.JoinType == LeftOuterJoin {
var eqCondWithCorCol []*expression.ScalarFunction
// Extract the equal condition.
for i := len(sel.Conditions) - 1; i >= 0; i-- {
if expr := apply.deCorColFromEqExpr(sel.Conditions[i]); expr != nil {
eqCondWithCorCol = append(eqCondWithCorCol, expr.(*expression.ScalarFunction))
sel.Conditions = append(sel.Conditions[:i], sel.Conditions[i+1:]...)
}
}
if len(eqCondWithCorCol) > 0 {
apply.extractCorColumnsBySchema()
// There's no other correlated column.
if len(apply.corCols) == 0 {
join := &apply.LogicalJoin
join.EqualConditions = append(join.EqualConditions, eqCondWithCorCol...)
for _, eqCond := range eqCondWithCorCol {
clonedCol := eqCond.GetArgs()[1].Clone()
// If the join key is not in the aggregation's schema, add first row function.
if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 {
newFunc := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false)
agg.AggFuncs = append(agg.AggFuncs, newFunc)
agg.schema.Append(clonedCol.(*expression.Column))
}
// If group by cols don't contain the join key, add it into this.
if agg.getGbyColIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 {
agg.GroupByItems = append(agg.GroupByItems, clonedCol)
}
}
agg.collectGroupByColumns()
}
// The selection may be useless, check and remove it.
if len(sel.Conditions) == 0 {
agg.SetChildren(sel.children[0])
}
defaultValueMap := s.aggDefaultValueMap(agg)
// We should use it directly, rather than building a projection.
if len(defaultValueMap) > 0 {
proj := LogicalProjection{}.init(agg.ctx)
proj.SetSchema(apply.schema)
proj.Exprs = expression.Column2Exprs(apply.schema.Columns)
for i, val := range defaultValueMap {
pos := proj.schema.ColumnIndex(agg.schema.Columns[i])
ifNullFunc := expression.NewFunctionInternal(agg.ctx, ast.Ifnull, types.NewFieldType(mysql.TypeLonglong), agg.schema.Columns[i].Clone(), val)
proj.Exprs[pos] = ifNullFunc
}
proj.SetChildren(apply)
p = proj
}
return s.optimize(p)
}
}
}
}
newChildren := make([]LogicalPlan, 0, len(p.Children()))
Expand Down
13 changes: 7 additions & 6 deletions plan/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,19 +569,16 @@ func (s *testPlanSuite) TestSubquery(c *C) {
best: "Apply{DataScan(t)->DataScan(s)->Sel([eq(s.a, test.t.a)])->Aggr(count(s.b))}->Projection",
},
{
// Theta-join with agg cannot decorrelate.
sql: "select (select count(s.b) k from t s where s.a = t.a having k != 0) from t",
best: "Apply{DataScan(t)->DataScan(s)->Sel([eq(s.a, test.t.a)])->Aggr(count(s.b))}->Projection->Projection",
best: "Join{DataScan(t)->DataScan(s)->Aggr(count(s.b),firstrow(s.a))}(test.t.a,s.a)->Projection->Projection->Projection",
},
{
// Relation without keys cannot decorrelate.
sql: "select (select count(s.b) k from t s where s.a = t1.a) from t t1, t t2",
best: "Apply{Join{DataScan(t1)->DataScan(t2)}->DataScan(s)->Sel([eq(s.a, t1.a)])->Aggr(count(s.b))}->Projection->Projection",
best: "Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(s)->Aggr(count(s.b),firstrow(s.a))}(t1.a,s.a)->Projection->Projection->Projection",
},
{
// Aggregate function like count(1) cannot decorrelate.
sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t",
best: "Apply{DataScan(t)->DataScan(s)->Sel([eq(s.a, test.t.a)])->Aggr(count(1))}->Projection->Projection",
best: "Join{DataScan(t)->DataScan(s)->Aggr(count(1),firstrow(s.a))}(test.t.a,s.a)->Projection->Projection->Projection",
},
{
sql: "select a from t where a in (select a from t s group by t.b)",
Expand Down Expand Up @@ -611,6 +608,10 @@ func (s *testPlanSuite) TestSubquery(c *C) {
sql: "select * from t where exists (select s.a from t s where s.c in (select c from t as k where k.d = s.d) having sum(s.a) = t.a )",
best: "Join{DataScan(t)->Join{DataScan(s)->DataScan(k)}(s.d,k.d)(s.c,k.c)->Aggr(sum(s.a))->Projection}->Projection",
},
{
sql: "select t1.b from t t1 where t1.b = (select max(t2.a) from t t2 where t1.b=t2.b)",
best: "Join{DataScan(t1)->DataScan(t2)->Aggr(max(t2.a),firstrow(t2.b))}(t1.b,t2.b)->Projection->Sel([eq(t1.b, max(t2.a))])->Projection",
},
}

for _, ca := range tests {
Expand Down
4 changes: 2 additions & 2 deletions plan/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ var AllowCartesianProduct = true

const (
flagPrunColumns uint64 = 1 << iota
flagMaxMinEliminate
flagEliminateProjection
flagBuildKeyInfo
flagDecorrelate
flagMaxMinEliminate
flagPredicatePushDown
flagAggregationOptimize
flagPushDownTopN
)

var optRuleList = []logicalOptRule{
&columnPruner{},
&maxMinEliminator{},
&projectionEliminater{},
&buildKeySolver{},
&decorrelateSolver{},
&maxMinEliminator{},
&ppdSolver{},
&aggregationOptimizer{},
&pushDownTopNOptimizer{},
Expand Down
6 changes: 3 additions & 3 deletions plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,11 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) {
},
{
sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t",
best: "Apply{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->Sel([eq(s.a, test.t.a)])->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}->Projection",
best: "LeftHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}(test.t.a,s.a)->Projection->Projection",
},
{
sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t order by t.a",
best: "Apply{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->Sel([eq(s.a, test.t.a)])->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}->Projection",
best: "LeftHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(s.a,t1.a)->StreamAgg}(test.t.a,s.a)->Projection->Sort->Projection",
},
}
for _, tt := range tests {
Expand Down Expand Up @@ -785,7 +785,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) {
},
{
sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t",
best: "Apply{TableReader(Table(t))->TableReader(Table(t))->Sel([eq(s.a, test.t.a)])->StreamAgg->Sel([ne(k, 0)])}->Projection",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t)->StreamAgg)->StreamAgg->Sel([ne(k, 0)])}(test.t.a,s.a)->Projection->Projection",
},
// Test stream agg with multi group by columns.
{
Expand Down

0 comments on commit 86af180

Please sign in to comment.