diff --git a/plan/decorrelate.go b/plan/decorrelate.go index 8c67344bf5c89..65a092cc5f8bd 100644 --- a/plan/decorrelate.go +++ b/plan/decorrelate.go @@ -202,15 +202,21 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { // We can pull up the equal conditions below the aggregation as the join key of the apply, if only // the equal conditions contain the correlated column of this apply. if sel, ok := agg.children[0].(*LogicalSelection); ok && apply.JoinType == LeftOuterJoin { - var eqCondWithCorCol []*expression.ScalarFunction + var ( + eqCondWithCorCol []*expression.ScalarFunction + remainedExpr []expression.Expression + ) // Extract the equal condition. - for i := len(sel.Conditions) - 1; i >= 0; i-- { - if expr := apply.deCorColFromEqExpr(sel.Conditions[i]); expr != nil { + for _, cond := range sel.Conditions { + if expr := apply.deCorColFromEqExpr(cond); expr != nil { eqCondWithCorCol = append(eqCondWithCorCol, expr.(*expression.ScalarFunction)) - sel.Conditions = append(sel.Conditions[:i], sel.Conditions[i+1:]...) + } else { + remainedExpr = append(remainedExpr, cond) } } if len(eqCondWithCorCol) > 0 { + originalExpr := sel.Conditions + sel.Conditions = remainedExpr apply.extractCorColumnsBySchema() // There's no other correlated column. if len(apply.corCols) == 0 { @@ -230,26 +236,28 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { } } agg.collectGroupByColumns() - } - // The selection may be useless, check and remove it. - if len(sel.Conditions) == 0 { - agg.SetChildren(sel.children[0]) - } - defaultValueMap := s.aggDefaultValueMap(agg) - // We should use it directly, rather than building a projection. - if len(defaultValueMap) > 0 { - proj := LogicalProjection{}.init(agg.ctx) - proj.SetSchema(apply.schema) - proj.Exprs = expression.Column2Exprs(apply.schema.Columns) - for i, val := range defaultValueMap { - pos := proj.schema.ColumnIndex(agg.schema.Columns[i]) - ifNullFunc := expression.NewFunctionInternal(agg.ctx, ast.Ifnull, types.NewFieldType(mysql.TypeLonglong), agg.schema.Columns[i].Clone(), val) - proj.Exprs[pos] = ifNullFunc + // The selection may be useless, check and remove it. + if len(sel.Conditions) == 0 { + agg.SetChildren(sel.children[0]) } - proj.SetChildren(apply) - p = proj + defaultValueMap := s.aggDefaultValueMap(agg) + // We should use it directly, rather than building a projection. + if len(defaultValueMap) > 0 { + proj := LogicalProjection{}.init(agg.ctx) + proj.SetSchema(apply.schema) + proj.Exprs = expression.Column2Exprs(apply.schema.Columns) + for i, val := range defaultValueMap { + pos := proj.schema.ColumnIndex(agg.schema.Columns[i]) + ifNullFunc := expression.NewFunctionInternal(agg.ctx, ast.Ifnull, types.NewFieldType(mysql.TypeLonglong), agg.schema.Columns[i], val) + proj.Exprs[pos] = ifNullFunc + } + proj.SetChildren(apply) + p = proj + } + return s.optimize(p) } - return s.optimize(p) + sel.Conditions = originalExpr + apply.extractCorColumnsBySchema() } } } diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index ae30579b8a390..ead573fc39175 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -507,6 +507,10 @@ func (s *testPlanSuite) TestSubquery(c *C) { sql: "select t1.b from t t1 where t1.b = (select max(t2.a) from t t2 where t1.b=t2.b)", best: "Join{DataScan(t1)->DataScan(t2)->Aggr(max(t2.a),firstrow(t2.b))}(t1.b,t2.b)->Projection->Sel([eq(t1.b, max(t2.a))])->Projection", }, + { + sql: "select t1.b from t t1 where t1.b = (select avg(t2.a) from t t2 where t1.g=t2.g and (t1.b = 4 or t2.b = 2))", + best: "Apply{DataScan(t1)->DataScan(t2)->Sel([eq(t1.g, t2.g) or(eq(t1.b, 4), eq(t2.b, 2))])->Aggr(avg(t2.a))}->Projection->Sel([eq(cast(t1.b), avg(t2.a))])->Projection", + }, } for _, ca := range tests {