From 6e3d2ecb48a931e835ad4826013e915051c6ebc4 Mon Sep 17 00:00:00 2001 From: Zhuhe Fang Date: Fri, 13 Nov 2020 15:42:19 +0800 Subject: [PATCH] expression: avoid unnecessary warnings/errors when folding constants in shortcut-able expressions (#19797) --- expression/function_traits.go | 10 +++++++--- expression/integration_test.go | 9 +++++++++ planner/core/expression_rewriter.go | 10 ++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/expression/function_traits.go b/expression/function_traits.go index 12631d357e6ff..8ba1ea7d0b91e 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -60,9 +60,13 @@ var DisableFoldFunctions = map[string]struct{}{ // otherwise, the child functions do not fold constant. // Note: the function itself should fold constant. var TryFoldFunctions = map[string]struct{}{ - ast.If: {}, - ast.Ifnull: {}, - ast.Case: {}, + ast.If: {}, + ast.Ifnull: {}, + ast.Case: {}, + ast.LogicAnd: {}, + ast.LogicOr: {}, + ast.Coalesce: {}, + ast.Interval: {}, } // IllegalFunctions4GeneratedColumns stores functions that is illegal for generated columns. diff --git a/expression/integration_test.go b/expression/integration_test.go index 63159b0f508a3..02e18b492e4d8 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2899,6 +2899,15 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { tk.MustQuery("select 1 or b/0 from t") tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select 1 or 1/0") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select 0 and 1/0") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select COALESCE(1, 1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select interval(1,0,1,2,1/0)") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0")) tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0")) tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02")) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 2ca683182b78b..a8e04c2999bd2 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -410,6 +410,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.ctxStackAppend(er.schema.Columns[index], er.names[index]) return inNode, true case *ast.FuncCallExpr: + er.asScalar = true if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { er.disableFoldCounter++ } @@ -417,12 +418,18 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.tryFoldCounter++ } case *ast.CaseExpr: + er.asScalar = true if _, ok := expression.DisableFoldFunctions["case"]; ok { er.disableFoldCounter++ } if _, ok := expression.TryFoldFunctions["case"]; ok { er.tryFoldCounter++ } + case *ast.BinaryOperationExpr: + er.asScalar = true + if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr { + er.tryFoldCounter++ + } case *ast.SetCollationExpr: // Do nothing default: @@ -1021,6 +1028,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok case *ast.UnaryOperationExpr: er.unaryOpToExpression(v) case *ast.BinaryOperationExpr: + if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr { + er.tryFoldCounter-- + } er.binaryOpToExpression(v) case *ast.BetweenExpr: er.betweenToExpression(v)