diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index d6b6a680498cf..a241b1133a27a 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -42,13 +42,11 @@ func evalAstExpr(ctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) if val, ok := expr.(*driver.ValueExpr); ok { return val.Datum, nil } - b := &PlanBuilder{ - ctx: ctx, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + var is infoschema.InfoSchema if ctx.GetSessionVars().TxnCtx.InfoSchema != nil { - b.is = ctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) + is = ctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) } + b := NewPlanBuilder(ctx, is) fakePlan := LogicalTableDual{}.Init(ctx) newExpr, _, err := b.rewrite(expr, fakePlan, nil, true) if err != nil { diff --git a/planner/core/indexmerge_test.go b/planner/core/indexmerge_test.go index 2ff5f09c91ef0..07ed20fb2ac69 100644 --- a/planner/core/indexmerge_test.go +++ b/planner/core/indexmerge_test.go @@ -16,7 +16,6 @@ package core import ( . "github.com/pingcap/check" "github.com/pingcap/parser" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/sessionctx" @@ -108,11 +107,7 @@ func (s *testIndexMergeSuite) TestIndexMergePathGenerateion(c *C) { stmt, err := s.ParseOneStmt(tc.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) if err != nil { c.Assert(err.Error(), Equals, tc.idxMergeDigest, comment) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index fd502d0b5ad20..d22ee24afa1c0 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2293,7 +2293,7 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { if err != nil { return nil, err } - if txn.Valid() && !txn.IsReadOnly() { + if txn.Valid() && !txn.IsReadOnly() && !isMemDB { us := LogicalUnionScan{}.Init(b.ctx) us.SetChildren(ds) result = us diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 25d94c1cf516b..5c1586e2548e9 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -21,7 +21,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -1714,11 +1713,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - colMapper: make(map[*ast.ColumnNameExpr]int), - ctx: MockContext(), - is: s.is, - } + builder := NewPlanBuilder(MockContext(), s.is) builder.ctx.GetSessionVars().HashJoinConcurrency = 1 _, err = builder.Build(stmt) c.Assert(err, IsNil, comment) @@ -1832,11 +1827,7 @@ func (s *testPlanSuite) TestUnion(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) plan, err := builder.Build(stmt) if tt.err { c.Assert(err, NotNil) @@ -1964,11 +1955,7 @@ func (s *testPlanSuite) TestTopNPushDown(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) c.Assert(err, IsNil) p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) @@ -2075,11 +2062,7 @@ func (s *testPlanSuite) TestOuterJoinEliminator(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) c.Assert(err, IsNil) p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) @@ -2106,11 +2089,7 @@ func (s *testPlanSuite) TestSelectView(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) c.Assert(err, IsNil) p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) @@ -2333,11 +2312,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) if err != nil { c.Assert(err.Error(), Equals, tt.result, comment) @@ -2418,11 +2393,7 @@ func (s *testPlanSuite) TestSkylinePruning(c *C) { stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(MockContext(), s.is) p, err := builder.Build(stmt) if err != nil { c.Assert(err.Error(), Equals, tt.result, comment) diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index d41dadde11710..ab7d09117fe19 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -74,11 +74,7 @@ type logicalOptRule interface { func BuildLogicalPlan(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, error) { ctx.GetSessionVars().PlanID = 0 ctx.GetSessionVars().PlanColumnID = 0 - builder := &PlanBuilder{ - ctx: ctx, - is: is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } + builder := NewPlanBuilder(ctx, is) p, err := builder.Build(node) if err != nil { return nil, err diff --git a/planner/core/planbuilder_test.go b/planner/core/planbuilder_test.go index 3c1487d1a3637..282e95a2cd9ff 100644 --- a/planner/core/planbuilder_test.go +++ b/planner/core/planbuilder_test.go @@ -94,9 +94,7 @@ func (s *testPlanBuilderSuite) TestGetPathByIndexName(c *C) { } func (s *testPlanBuilderSuite) TestRewriterPool(c *C) { - builder := &PlanBuilder{ - ctx: MockContext(), - } + builder := NewPlanBuilder(MockContext(), nil) // Make sure PlanBuilder.getExpressionRewriter() provides clean rewriter from pool. // First, pick one rewriter from the pool and make it dirty. @@ -149,7 +147,7 @@ func (s *testPlanBuilderSuite) TestDisableFold(c *C) { stmt := st.(*ast.SelectStmt) expr := stmt.Fields.Fields[0].Expr - builder := &PlanBuilder{ctx: ctx} + builder := NewPlanBuilder(ctx, nil) builder.rewriterCounter++ rewriter := builder.getExpressionRewriter(nil) c.Assert(rewriter, NotNil) diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 4b04412e5c02a..f2c3ebc59db63 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -45,12 +45,11 @@ func (s *partitionProcessor) optimize(lp LogicalPlan) (LogicalPlan, error) { func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, error) { // Assert there will not be sel -> sel in the ast. - switch lp.(type) { + switch p := lp.(type) { case *DataSource: - return s.prune(lp.(*DataSource)) + return s.prune(p) case *LogicalUnionScan: - us := lp.(*LogicalUnionScan) - ds := us.Children()[0] + ds := p.Children()[0] ds, err := s.prune(ds.(*DataSource)) if err != nil { return nil, err @@ -60,17 +59,16 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err // Union->(UnionScan->DataSource1), (UnionScan->DataSource2) children := make([]LogicalPlan, 0, len(ua.Children())) for _, child := range ua.Children() { - usChild := LogicalUnionScan{}.Init(ua.ctx) - usChild.conditions = us.conditions - usChild.SetChildren(child) - children = append(children, usChild) + us := LogicalUnionScan{conditions: p.conditions}.Init(ua.ctx) + us.SetChildren(child) + children = append(children, us) } ua.SetChildren(children...) return ua, nil } // Only one partition, no union all. - us.SetChildren(ds) - return us, nil + p.SetChildren(ds) + return p, nil default: children := lp.Children() for i, child := range children {