From c79e12ee8103da931098fa80417650a1a644f4a2 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Wed, 16 Jan 2019 12:53:24 +0800 Subject: [PATCH 1/4] planner: check null and empty for `!= any(subq)` and `= all(subq)` --- cmd/explaintest/r/select.result | 22 ++++++ cmd/explaintest/t/select.test | 6 ++ planner/core/expression_rewriter.go | 56 ++++++++------ planner/core/expression_rewriter_test.go | 94 ++++++++++++++++++++++++ 4 files changed, 154 insertions(+), 24 deletions(-) diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index 069f8ad4a435d..7ac706e26eee7 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -358,3 +358,25 @@ Union_7 20000.00 root │ └─TableScan_8 10000.00 cop table:th, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo └─TableReader_11 10000.00 root data:TableScan_10 └─TableScan_10 10000.00 cop table:th, partition:p2, range:[-inf,+inf], keep order:false, stats:pseudo +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 0)), and(ne(agg_col_cnt, 0), if(isnull(t1.a), NULL, 1))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(col_3) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)), isnull(t2.a) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +explain select a = all (select a from t t2) from t t1; +id count task operator info +Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 1)), or(eq(agg_col_cnt, 0), if(isnull(t1.a), NULL, 0))) +└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 + ├─TableReader_13 10000.00 root data:TableScan_12 + │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(col_3) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)), isnull(t2.a) + └─TableReader_24 10000.00 root data:TableScan_23 + └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/t/select.test b/cmd/explaintest/t/select.test index 5070a52061206..80143354ff88e 100644 --- a/cmd/explaintest/t/select.test +++ b/cmd/explaintest/t/select.test @@ -175,3 +175,9 @@ insert into th values (-1,-1),(-2,-2),(-3,-3),(-4,-4),(-5,-5),(-6,-6),(-7,-7),(- desc select * from th where a=-2; desc select * from th; desc select * from th partition (p2,p1); + +# test != any(subq) and = all(subq) +drop table if exists t; +create table t(a int, b int); +explain select a != any (select a from t t2) from t t1; +explain select a = all (select a from t t2) from t t1; diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 5adc459cbca8f..4ffc2c36038df 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -428,14 +428,15 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, all) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all) } // buildQuantifierPlan adds extra condition for any / all subquery. -func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, rexpr expression.Expression, all bool) { - funcIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) +func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) { + innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) + outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{funcIsNull}, false) + funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -443,29 +444,36 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) plan4Agg.schema.Append(colSum) + innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) + + funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{innerIsNull}, false) + colCount := &expression.Column{ + ColName: model.NewCIStr("agg_col_cnt"), + UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: funcCount.RetTp, + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) + plan4Agg.schema.Append(colCount) if all { - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{funcIsNull}, false) - colCount := &expression.Column{ - ColName: model.NewCIStr("agg_col_cnt"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: funcCount.RetTp, - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) - plan4Agg.schema.Append(colCount) // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null). - hasNotNull := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNotNull, expression.One, expression.Null) - cond = expression.ComposeCNFCondition(er.ctx, cond, nullChecker) - // If the set is empty, it should always return true. - checkEmpty := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, checkEmpty) + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return true. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } else { - // For "any" expression, if the record set has null and the cond return false, the result should be NULL. - hasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - nullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), hasNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, nullChecker) + // For "any" expression, if the subquery has null and the cond return false, the result should be NULL. + innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) + // If the subquery is empty, it should always return false. + emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + // If outer key is null, and subquery is not empty, it should return null. + outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. @@ -518,7 +526,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One) neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, false) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false) } // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to @@ -544,7 +552,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One) eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond) - er.buildQuantifierPlan(plan4Agg, cond, rexpr, true) + er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true) } func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) { diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 6089ac43f20ea..805c81e7d3449 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -129,3 +129,97 @@ func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) { tk.MustExec("update t1 set c = c + default(c)") tk.MustQuery("select c from t1").Check(testkit.Rows("11")) } + +func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists s") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into t values(1, null), (2, null)") + + // Test empty checker. + tk.MustQuery("select a != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select b != any (select a from s) from t").Check(testkit.Rows( + "0", + "0", + )) + tk.MustQuery("select a = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select b = all (select a from s) from t").Check(testkit.Rows( + "1", + "1", + )) + tk.MustQuery("select * from t where a != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where b != any (select a from s)").Check(testkit.Rows()) + tk.MustQuery("select * from t where a = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + tk.MustQuery("select * from t where b = all (select a from s)").Check(testkit.Rows( + "1 ", + "2 ", + )) + // Test outer null checker. + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where a = 2") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + // Test inner null checker. + tk.MustExec("insert into t values(null, 1)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) + + tk.MustExec("delete from t where b = 1") + tk.MustExec("insert into t values(null, 2)") + tk.MustQuery("select b != any (select a from t t2) from t t1").Check(testkit.Rows( + "", + "1", + )) + tk.MustQuery("select b = all (select a from t t2) from t t1").Check(testkit.Rows( + "", + "0", + )) + tk.MustQuery("select * from t t1 where b != any (select a from t t2)").Check(testkit.Rows( + " 2", + )) + tk.MustQuery("select * from t t1 where b = all (select a from t t2)").Check(testkit.Rows()) +} From 93a17241400cd789a2ce7e48c36f148a44a13718 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Wed, 16 Jan 2019 17:06:26 +0800 Subject: [PATCH 2/4] adjust comments --- planner/core/expression_rewriter.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4ffc2c36038df..25583ccca11ed 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -457,7 +457,7 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, if all { // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it - // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) = 0, true, null). + // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) // If the subquery is empty, it should always return true. @@ -466,7 +466,8 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) } else { - // For "any" expression, if the subquery has null and the cond return false, the result should be NULL. + // For "any" expression, if the subquery has null and the cond returns false, the result should be NULL. + // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) // If the subquery is empty, it should always return false. From 5f62d5171461b5e537343e1922facdb093a1ae53 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Wed, 16 Jan 2019 22:42:34 +0800 Subject: [PATCH 3/4] change count(isnull) to count(1) --- planner/core/expression_rewriter.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 435d1c9f96359..09f66f4efd434 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -447,7 +447,8 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, plan4Agg.schema.Append(colSum) innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{innerIsNull}, false) + // Build `count(1)` aggregation to check if subquery is empty. + funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) colCount := &expression.Column{ ColName: model.NewCIStr("agg_col_cnt"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), From 2496e1ce714d5873a60b59b448352de187764746 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Wed, 16 Jan 2019 22:49:13 +0800 Subject: [PATCH 4/4] update explain result --- cmd/explaintest/r/select.result | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index 7ac706e26eee7..03a5ceb0a6289 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -366,8 +366,8 @@ Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(t1.a, col_firstrow)), └─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 ├─TableReader_13 10000.00 root data:TableScan_12 │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(col_3) - └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)), isnull(t2.a) + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) └─TableReader_24 10000.00 root data:TableScan_23 └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo explain select a = all (select a from t t2) from t t1; @@ -376,7 +376,7 @@ Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(t1.a, col_firstrow)), └─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 ├─TableReader_13 10000.00 root data:TableScan_12 │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(col_3) - └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)), isnull(t2.a) + └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) + └─Projection_27 10000.00 root t2.a, t2.a, cast(isnull(t2.a)) └─TableReader_24 10000.00 root data:TableScan_23 └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo