From ddc9a4e5738a5123b07411f2f39818c6d1fdcef5 Mon Sep 17 00:00:00 2001 From: Mingcong Han Date: Mon, 21 Oct 2019 16:29:19 +0800 Subject: [PATCH] planner/cascades: implement ImplementationRule for Selection (#12257) --- planner/cascades/implementation_rules.go | 31 ++++- planner/cascades/integration_test.go | 37 +++--- planner/cascades/optimize.go | 11 +- .../testdata/integration_suite_in.json | 19 +++ .../testdata/integration_suite_out.json | 80 +++++++++++++ planner/cascades/transformation_rules.go | 15 +-- planner/core/plan.go | 8 ++ planner/implementation/base.go | 6 +- planner/implementation/base_test.go | 8 +- planner/implementation/datasource.go | 8 +- planner/implementation/simple_plans.go | 35 +++++- planner/implementation/sort.go | 6 +- planner/memo/expr_iterator.go | 51 +++++--- planner/memo/expr_iterator_test.go | 109 +++++++++++++++++- planner/memo/group.go | 58 ++++++++++ planner/memo/group_test.go | 30 ++++- planner/memo/implementation.go | 2 +- planner/memo/pattern.go | 31 +++-- planner/memo/pattern_test.go | 12 +- 19 files changed, 472 insertions(+), 85 deletions(-) create mode 100644 planner/cascades/testdata/integration_suite_in.json create mode 100644 planner/cascades/testdata/integration_suite_out.json diff --git a/planner/cascades/implementation_rules.go b/planner/cascades/implementation_rules.go index 658cd8ca9fdb2..3c7381adc3bf2 100644 --- a/planner/cascades/implementation_rules.go +++ b/planner/cascades/implementation_rules.go @@ -45,6 +45,9 @@ var defaultImplementationMap = map[memo.Operand][]ImplementationRule{ memo.OperandShow: { &ImplShow{}, }, + memo.OperandSelection: { + &ImplSelection{}, + }, } // ImplTableDual implements LogicalTableDual as PhysicalTableDual. @@ -151,9 +154,35 @@ func (r *ImplShow) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalP // TODO(zz-jason): unifying LogicalShow and PhysicalShow to a single // struct. So that we don't need to create a new PhysicalShow object, which - // can help us to reduce the gc presure of golang runtime and improve the + // can help us to reduce the gc pressure of golang runtime and improve the // overall performance. showPhys := plannercore.PhysicalShow{ShowContents: show.ShowContents}.Init(show.SCtx()) showPhys.SetSchema(logicProp.Schema) return impl.NewShowImpl(showPhys), nil } + +// ImplSelection is the implementation rule which implements LogicalSelection +// to PhysicalSelection. +type ImplSelection struct { +} + +// Match implements ImplementationRule Match interface. +func (r *ImplSelection) Match(expr *memo.GroupExpr, prop *property.PhysicalProperty) (matched bool) { + return true +} + +// OnImplement implements ImplementationRule OnImplement interface. +func (r *ImplSelection) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalProperty) (memo.Implementation, error) { + logicalSel := expr.ExprNode.(*plannercore.LogicalSelection) + physicalSel := plannercore.PhysicalSelection{ + Conditions: logicalSel.Conditions, + }.Init(logicalSel.SCtx(), expr.Group.Prop.Stats.ScaleByExpectCnt(reqProp.ExpectedCnt), logicalSel.SelectBlockOffset(), reqProp.Clone()) + switch expr.Group.EngineType { + case memo.EngineTiDB: + return impl.NewTiDBSelectionImpl(physicalSel), nil + case memo.EngineTiKV: + return impl.NewTiKVSelectionImpl(physicalSel), nil + default: + return nil, plannercore.ErrInternal.GenWithStack("Unsupported EngineType '%s' for Selection.", expr.Group.EngineType.String()) + } +} diff --git a/planner/cascades/integration_test.go b/planner/cascades/integration_test.go index ce935f48a6e07..a9e312047a9dc 100644 --- a/planner/cascades/integration_test.go +++ b/planner/cascades/integration_test.go @@ -19,12 +19,14 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testIntegrationSuite{}) type testIntegrationSuite struct { - store kv.Storage + store kv.Storage + testData testutil.TestData } func newStoreWithBootstrap() (kv.Storage, error) { @@ -40,9 +42,12 @@ func (s *testIntegrationSuite) SetUpSuite(c *C) { var err error s.store, err = newStoreWithBootstrap() c.Assert(err, IsNil) + s.testData, err = testutil.LoadTestSuiteData("testdata", "integration_suite") + c.Assert(err, IsNil) } func (s *testIntegrationSuite) TearDownSuite(c *C) { + c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) s.store.Close() } @@ -62,22 +67,22 @@ func (s *testIntegrationSuite) TestPKIsHandleRangeScan(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key, b int)") - tk.MustExec("insert into t values(1,2),(3,4)") + tk.MustExec("insert into t values(1,2),(3,4),(5,6)") tk.MustExec("set session tidb_enable_cascades_planner = 1") - tk.MustQuery("explain select b from t where a > 1").Check(testkit.Rows( - "Projection_8 3333.33 root Column#2", - "└─TableReader_9 3333.33 root data:TableScan_10", - " └─TableScan_10 3333.33 cop[tikv] table:t, range:(1,+inf], keep order:false, stats:pseudo", - )) - tk.MustQuery("select b from t where a > 1").Check(testkit.Rows( - "4", - )) - tk.MustQuery("explain select b from t where a > 1 and a < 3").Check(testkit.Rows( - "Projection_8 2.00 root Column#2", - "└─TableReader_9 2.00 root data:TableScan_10", - " └─TableScan_10 2.00 cop[tikv] table:t, range:(1,3), keep order:false, stats:pseudo", - )) - tk.MustQuery("select b from t where a > 1 and a < 3").Check(testkit.Rows()) + + var input []string + var output []struct { + SQL string + Result []string + } + s.testData.GetTestCases(c, &input, &output) + for i, sql := range input { + s.testData.OnRecord(func() { + output[i].SQL = sql + output[i].Result = s.testData.ConvertRowsToStrings(tk.MustQuery(sql).Rows()) + }) + tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...)) + } } func (s *testIntegrationSuite) TestBasicShow(c *C) { diff --git a/planner/cascades/optimize.go b/planner/cascades/optimize.go index 927a9e5d72fa3..478ab6b286183 100644 --- a/planner/cascades/optimize.go +++ b/planner/cascades/optimize.go @@ -285,7 +285,7 @@ func (opt *Optimizer) implGroup(g *memo.Group, reqPhysProp *property.PhysicalPro } // Handle implementation rules for each equivalent GroupExpr. var cumCost float64 - var childCosts []float64 + var childImpls []memo.Implementation var childPlans []plannercore.PhysicalPlan err := opt.fillGroupStats(g) if err != nil { @@ -300,8 +300,8 @@ func (opt *Optimizer) implGroup(g *memo.Group, reqPhysProp *property.PhysicalPro } for _, impl := range impls { cumCost = 0.0 - childCosts = childCosts[:0] childPlans = childPlans[:0] + childImpls = childImpls[:0] for i, childGroup := range curExpr.Children { childImpl, err := opt.implGroup(childGroup, impl.GetPlan().GetChildReqProps(i), costLimit-cumCost) if err != nil { @@ -311,15 +311,14 @@ func (opt *Optimizer) implGroup(g *memo.Group, reqPhysProp *property.PhysicalPro impl.SetCost(math.MaxFloat64) break } - childCost := childImpl.GetCost() - childCosts = append(childCosts, childCost) - cumCost += childCost + cumCost += childImpl.GetCost() + childImpls = append(childImpls, childImpl) childPlans = append(childPlans, childImpl.GetPlan()) } if impl.GetCost() == math.MaxFloat64 { continue } - cumCost = impl.CalcCost(outCount, childCosts, curExpr.Children...) + cumCost = impl.CalcCost(outCount, childImpls...) if cumCost > costLimit { continue } diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json new file mode 100644 index 0000000000000..bd124ea65a852 --- /dev/null +++ b/planner/cascades/testdata/integration_suite_in.json @@ -0,0 +1,19 @@ +[ + { + "name": "TestPKIsHandleRangeScan", + "cases": [ + "explain select b from t where a > 1", + "select b from t where a > 1", + "explain select b from t where a > 1 and a < 3", + "select b from t where a > 1 and a < 3", + "explain select b from t where a > 1 and b < 6", + "select b from t where a > 1 and b < 6", + "explain select a from t where a * 3 + 1 > 9 and a < 5", + "select a from t where a * 3 + 1 > 9 and a < 5", + // Test TiDBSelection Implementation. + // TODO: change this test case to agg + sel or join + sel when we support them. + "explain select a from t where a * 3 + 1 > 9 and sin(a) < 0.5 and a < 5", + "select a from t where a * 3 + 1 > 9 and sin(a) < 0.5 and a < 5" + ] + } +] diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json new file mode 100644 index 0000000000000..83e3b7b011fe9 --- /dev/null +++ b/planner/cascades/testdata/integration_suite_out.json @@ -0,0 +1,80 @@ +[ + { + "Name": "TestPKIsHandleRangeScan", + "Cases": [ + { + "SQL": "explain select b from t where a > 1", + "Result": [ + "Projection_8 3333.33 root Column#2", + "└─TableReader_9 3333.33 root data:TableScan_10", + " └─TableScan_10 3333.33 cop[tikv] table:t, range:(1,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select b from t where a > 1", + "Result": [ + "4", + "6" + ] + }, + { + "SQL": "explain select b from t where a > 1 and a < 3", + "Result": [ + "Projection_8 2.00 root Column#2", + "└─TableReader_9 2.00 root data:TableScan_10", + " └─TableScan_10 2.00 cop[tikv] table:t, range:(1,3), keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select b from t where a > 1 and a < 3", + "Result": null + }, + { + "SQL": "explain select b from t where a > 1 and b < 6", + "Result": [ + "Projection_9 2666.67 root Column#2", + "└─TableReader_10 2666.67 root data:Selection_11", + " └─Selection_11 2666.67 cop[tikv] lt(Column#2, 6)", + " └─TableScan_12 3333.33 cop[tikv] table:t, range:(1,+inf], keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select b from t where a > 1 and b < 6", + "Result": [ + "4" + ] + }, + { + "SQL": "explain select a from t where a * 3 + 1 > 9 and a < 5", + "Result": [ + "Projection_9 4.00 root Column#1", + "└─TableReader_10 4.00 root data:Selection_11", + " └─Selection_11 4.00 cop[tikv] gt(plus(mul(Column#1, 3), 1), 9)", + " └─TableScan_12 5.00 cop[tikv] table:t, range:[-inf,5), keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from t where a * 3 + 1 > 9 and a < 5", + "Result": [ + "3" + ] + }, + { + "SQL": "explain select a from t where a * 3 + 1 > 9 and sin(a) < 0.5 and a < 5", + "Result": [ + "Projection_10 3.20 root Column#1", + "└─Selection_11 3.20 root lt(sin(cast(Column#1)), 0.5)", + " └─TableReader_12 4.00 root data:Selection_13", + " └─Selection_13 4.00 cop[tikv] gt(plus(mul(Column#1, 3), 1), 9)", + " └─TableScan_14 5.00 cop[tikv] table:t, range:[-inf,5), keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from t where a * 3 + 1 > 9 and sin(a) < 0.5 and a < 5", + "Result": [ + "3" + ] + } + ] + } +] diff --git a/planner/cascades/transformation_rules.go b/planner/cascades/transformation_rules.go index bab71410176c7..4a47666a6a102 100644 --- a/planner/cascades/transformation_rules.go +++ b/planner/cascades/transformation_rules.go @@ -61,8 +61,8 @@ func (r *PushSelDownTableScan) GetPattern() *memo.Pattern { if p, ok := patternMap[r]; ok { return p } - ts := memo.NewPattern(memo.OperandTableScan) - p := memo.BuildPattern(memo.OperandSelection, ts) + ts := memo.NewPattern(memo.OperandTableScan, memo.EngineTiKVOrTiFlash) + p := memo.BuildPattern(memo.OperandSelection, memo.EngineTiKVOrTiFlash, ts) patternMap[r] = p return p } @@ -120,9 +120,9 @@ func (r *PushSelDownTableGather) GetPattern() *memo.Pattern { if p, ok := patternMap[r]; ok { return p } - any := memo.NewPattern(memo.OperandAny) - tg := memo.BuildPattern(memo.OperandTableGather, any) - p := memo.BuildPattern(memo.OperandSelection, tg) + any := memo.NewPattern(memo.OperandAny, memo.EngineTiKVOrTiFlash) + tg := memo.BuildPattern(memo.OperandTableGather, memo.EngineTiDBOnly, any) + p := memo.BuildPattern(memo.OperandSelection, memo.EngineTiDBOnly, tg) patternMap[r] = p return p } @@ -150,7 +150,7 @@ func (r *PushSelDownTableGather) OnTransform(old *memo.ExprIter) (newExprs []*me pushedSel := plannercore.LogicalSelection{Conditions: pushed}.Init(sctx, sel.SelectBlockOffset()) pushedSelExpr := memo.NewGroupExpr(pushedSel) pushedSelExpr.Children = append(pushedSelExpr.Children, childGroup) - pushedSelGroup := memo.NewGroupWithSchema(pushedSelExpr, childGroup.Prop.Schema) + pushedSelGroup := memo.NewGroupWithSchema(pushedSelExpr, childGroup.Prop.Schema).SetEngineType(childGroup.EngineType) // The field content of TableGather would not be modified currently, so we // just reference the same tg instead of making a copy of it. // @@ -179,7 +179,7 @@ func (r *EnumeratePaths) GetPattern() *memo.Pattern { if p, ok := patternMap[r]; ok { return p } - p := memo.NewPattern(memo.OperandDataSource) + p := memo.NewPattern(memo.OperandDataSource, memo.EngineTiDBOnly) patternMap[r] = p return p } @@ -195,6 +195,7 @@ func (r *EnumeratePaths) OnTransform(old *memo.ExprIter) (newExprs []*memo.Group gathers := ds.Convert2Gathers() for _, gather := range gathers { expr := convert2GroupExpr(gather) + expr.Children[0].SetEngineType(memo.EngineTiKV) newExprs = append(newExprs, expr) } return newExprs, true, false, nil diff --git a/planner/core/plan.go b/planner/core/plan.go index fe8383b65ba28..b51b27fdf7d0c 100644 --- a/planner/core/plan.go +++ b/planner/core/plan.go @@ -161,6 +161,9 @@ type PhysicalPlan interface { // ResolveIndices resolves the indices for columns. After doing this, the columns can evaluate the rows by their indices. ResolveIndices() error + + // Stats returns the StatsInfo of the plan. + Stats() *property.StatsInfo } type baseLogicalPlan struct { @@ -311,6 +314,11 @@ func (p *basePlan) SelectBlockOffset() int { return p.blockOffset } +// Stats implements Plan Stats interface. +func (p *basePlan) Stats() *property.StatsInfo { + return p.stats +} + // Schema implements Plan Schema interface. func (p *baseLogicalPlan) Schema() *expression.Schema { return p.children[0].Schema() diff --git a/planner/implementation/base.go b/planner/implementation/base.go index ddd87bcd3af2c..8e08efa1a980d 100644 --- a/planner/implementation/base.go +++ b/planner/implementation/base.go @@ -23,10 +23,10 @@ type baseImpl struct { plan plannercore.PhysicalPlan } -func (impl *baseImpl) CalcCost(outCount float64, childCosts []float64, children ...*memo.Group) float64 { +func (impl *baseImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { impl.cost = 0 - for _, childCost := range childCosts { - impl.cost += childCost + for _, child := range children { + impl.cost += child.GetCost() } return impl.cost } diff --git a/planner/implementation/base_test.go b/planner/implementation/base_test.go index f38389f82dbfd..a19b04da3842b 100644 --- a/planner/implementation/base_test.go +++ b/planner/implementation/base_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/infoschema" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/memo" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/testleak" ) @@ -54,10 +55,9 @@ func (s *testImplSuite) TestBaseImplementation(c *C) { impl := &baseImpl{plan: p} c.Assert(impl.GetPlan(), Equals, p) - childCosts := []float64{5.0} - cost := impl.CalcCost(10, childCosts, nil) - c.Assert(cost, Equals, 5.0) - c.Assert(impl.GetCost(), Equals, 5.0) + cost := impl.CalcCost(10, []memo.Implementation{}...) + c.Assert(cost, Equals, 0.0) + c.Assert(impl.GetCost(), Equals, 0.0) impl.SetCost(6.0) c.Assert(impl.GetCost(), Equals, 6.0) diff --git a/planner/implementation/datasource.go b/planner/implementation/datasource.go index 45b7feeacd7f5..ec5a85d4e8820 100644 --- a/planner/implementation/datasource.go +++ b/planner/implementation/datasource.go @@ -31,7 +31,7 @@ func NewTableDualImpl(dual *plannercore.PhysicalTableDual) *TableDualImpl { } // CalcCost calculates the cost of the table dual Implementation. -func (impl *TableDualImpl) CalcCost(outCount float64, childCosts []float64, children ...*memo.Group) float64 { +func (impl *TableDualImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { return 0 } @@ -52,7 +52,7 @@ func NewTableReaderImpl(reader *plannercore.PhysicalTableReader, hists *statisti } // CalcCost calculates the cost of the table reader Implementation. -func (impl *TableReaderImpl) CalcCost(outCount float64, childCosts []float64, children ...*memo.Group) float64 { +func (impl *TableReaderImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { reader := impl.plan.(*plannercore.PhysicalTableReader) width := impl.tblColHists.GetAvgRowSize(reader.Schema().Columns, false) sessVars := reader.SCtx().GetSessionVars() @@ -62,7 +62,7 @@ func (impl *TableReaderImpl) CalcCost(outCount float64, childCosts []float64, ch // is Min(DistSQLScanConcurrency, numRegionsInvolvedInScan), since we cannot infer // the number of regions involved, we simply use DistSQLScanConcurrency. copIterWorkers := float64(sessVars.DistSQLScanConcurrency) - impl.cost = (networkCost + childCosts[0]) / copIterWorkers + impl.cost = (networkCost + children[0].GetCost()) / copIterWorkers return impl.cost } @@ -85,7 +85,7 @@ func NewTableScanImpl(ts *plannercore.PhysicalTableScan, cols []*expression.Colu } // CalcCost calculates the cost of the table scan Implementation. -func (impl *TableScanImpl) CalcCost(outCount float64, childCosts []float64, children ...*memo.Group) float64 { +func (impl *TableScanImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { ts := impl.plan.(*plannercore.PhysicalTableScan) width := impl.tblColHists.GetAvgRowSize(impl.tblCols, false) sessVars := ts.SCtx().GetSessionVars() diff --git a/planner/implementation/simple_plans.go b/planner/implementation/simple_plans.go index 6b51e3349c8eb..6b0fb085e7284 100644 --- a/planner/implementation/simple_plans.go +++ b/planner/implementation/simple_plans.go @@ -15,9 +15,10 @@ package implementation import ( plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/memo" ) -// ProjectionImpl implementation of PhysicalProjection. +// ProjectionImpl is the implementation of PhysicalProjection. type ProjectionImpl struct { baseImpl } @@ -36,3 +37,35 @@ type ShowImpl struct { func NewShowImpl(show *plannercore.PhysicalShow) *ShowImpl { return &ShowImpl{baseImpl: baseImpl{plan: show}} } + +// TiDBSelectionImpl is the implementation of PhysicalSelection in TiDB layer. +type TiDBSelectionImpl struct { + baseImpl +} + +// CalcCost implements Implementation CalcCost interface. +func (sel *TiDBSelectionImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { + sel.cost = children[0].GetPlan().Stats().RowCount*sel.plan.SCtx().GetSessionVars().CPUFactor + children[0].GetCost() + return sel.cost +} + +// NewTiDBSelectionImpl creates a new TiDBSelectionImpl. +func NewTiDBSelectionImpl(sel *plannercore.PhysicalSelection) *TiDBSelectionImpl { + return &TiDBSelectionImpl{baseImpl{plan: sel}} +} + +// TiKVSelectionImpl is the implementation of PhysicalSelection in TiKV layer. +type TiKVSelectionImpl struct { + baseImpl +} + +// CalcCost implements Implementation CalcCost interface. +func (sel *TiKVSelectionImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { + sel.cost = children[0].GetPlan().Stats().RowCount*sel.plan.SCtx().GetSessionVars().CopCPUFactor + children[0].GetCost() + return sel.cost +} + +// NewTiKVSelectionImpl creates a new TiKVSelectionImpl. +func NewTiKVSelectionImpl(sel *plannercore.PhysicalSelection) *TiKVSelectionImpl { + return &TiKVSelectionImpl{baseImpl{plan: sel}} +} diff --git a/planner/implementation/sort.go b/planner/implementation/sort.go index b4dc76cf7c338..296123fdb7cab 100644 --- a/planner/implementation/sort.go +++ b/planner/implementation/sort.go @@ -31,9 +31,9 @@ func NewSortImpl(sort *plannercore.PhysicalSort) *SortImpl { } // CalcCost calculates the cost of the sort Implementation. -func (impl *SortImpl) CalcCost(outCount float64, childCosts []float64, children ...*memo.Group) float64 { - cnt := math.Min(children[0].Prop.Stats.RowCount, impl.plan.GetChildReqProps(0).ExpectedCnt) +func (impl *SortImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { + cnt := math.Min(children[0].GetPlan().Stats().RowCount, impl.plan.GetChildReqProps(0).ExpectedCnt) sort := impl.plan.(*plannercore.PhysicalSort) - impl.cost = sort.GetCost(cnt) + childCosts[0] + impl.cost = sort.GetCost(cnt) + children[0].GetCost() return impl.cost } diff --git a/planner/memo/expr_iterator.go b/planner/memo/expr_iterator.go index 7e8979976b820..e479c2cdd1dd4 100644 --- a/planner/memo/expr_iterator.go +++ b/planner/memo/expr_iterator.go @@ -24,13 +24,14 @@ type ExprIter struct { *Group *list.Element - // matched indicates whether the current Group expression binded by the + // matched indicates whether the current Group expression bound by the // iterator matches the pattern after the creation or iteration. matched bool - // Operand is the node of the pattern tree. The Operand type of the Group - // expression must be matched with it. - Operand + // Pattern describes the node of pattern tree. + // The Operand type of the Group expression and the EngineType of the Group + // must be matched with it. + *Pattern // Children is used to iterate the child expressions. Children []*ExprIter @@ -62,15 +63,18 @@ func (iter *ExprIter) Next() (found bool) { // Otherwise, iterate itself to find more matched equivalent expressions. for elem := iter.Element.Next(); elem != nil; elem = elem.Next() { expr := elem.Value.(*GroupExpr) - exprOperand := GetOperand(expr.ExprNode) - if !iter.Operand.Match(exprOperand) { + if !iter.Operand.Match(GetOperand(expr.ExprNode)) { // All the Equivalents which have the same Operand are continuously // stored in the list. Once the current equivalent can not Match // the Operand, the rest can not, either. return false } + if len(iter.Children) == 0 { + iter.Element = elem + return true + } if len(iter.Children) != len(expr.Children) { continue } @@ -102,17 +106,34 @@ func (iter *ExprIter) Matched() bool { func (iter *ExprIter) Reset() (findMatch bool) { defer func() { iter.matched = findMatch }() - if iter.Operand == OperandAny { + if iter.Pattern.MatchOperandAny(iter.Group.EngineType) { return true } for elem := iter.Group.GetFirstElem(iter.Operand); elem != nil; elem = elem.Next() { expr := elem.Value.(*GroupExpr) - exprOperand := GetOperand(expr.ExprNode) - if !iter.Operand.Match(exprOperand) { + + if !iter.Pattern.Match(GetOperand(expr.ExprNode), expr.Group.EngineType) { break } + // The leaf node of the pattern tree might not be an OperandAny or a XXXScan. + // We allow the patterns like: Selection -> Projection. + // For example, we have such a memo: + // Group#1 + // Selection_0 input:[Group#2] + // Group#2 + // Projection_1 input:[Group#3] + // Projection_2 input:[Group#4] + // Group#3 + // ..... + // For the pattern above, we will match it twice: `Selection_0->Projection_1` + // and `Selection_0->Projection_2`. So if the iterator has no children, we can safely return + // the element here. + if len(iter.Children) == 0 { + iter.Element = elem + return true + } if len(expr.Children) != len(iter.Children) { continue } @@ -141,7 +162,7 @@ func (iter *ExprIter) GetExpr() *GroupExpr { // NewExprIterFromGroupElem creates the iterator on the Group Element. func NewExprIterFromGroupElem(elem *list.Element, p *Pattern) *ExprIter { expr := elem.Value.(*GroupExpr) - if !p.Operand.Match(GetOperand(expr.ExprNode)) { + if !p.Match(GetOperand(expr.ExprNode), expr.Group.EngineType) { return nil } iter := newExprIterFromGroupExpr(expr, p) @@ -153,10 +174,10 @@ func NewExprIterFromGroupElem(elem *list.Element, p *Pattern) *ExprIter { // newExprIterFromGroupExpr creates the iterator on the Group expression. func newExprIterFromGroupExpr(expr *GroupExpr, p *Pattern) *ExprIter { - if len(p.Children) != len(expr.Children) { + if len(p.Children) != 0 && len(p.Children) != len(expr.Children) { return nil } - iter := &ExprIter{Operand: p.Operand, matched: true} + iter := &ExprIter{Pattern: p, matched: true} for i := range p.Children { childIter := newExprIterFromGroup(expr.Children[i], p.Children[i]) if childIter == nil { @@ -169,12 +190,12 @@ func newExprIterFromGroupExpr(expr *GroupExpr, p *Pattern) *ExprIter { // newExprIterFromGroup creates the iterator on the Group. func newExprIterFromGroup(g *Group, p *Pattern) *ExprIter { - if p.Operand == OperandAny { - return &ExprIter{Group: g, Operand: OperandAny, matched: true} + if p.MatchOperandAny(g.EngineType) { + return &ExprIter{Group: g, Pattern: p, matched: true} } for elem := g.GetFirstElem(p.Operand); elem != nil; elem = elem.Next() { expr := elem.Value.(*GroupExpr) - if !p.Operand.Match(GetOperand(expr.ExprNode)) { + if !p.Match(GetOperand(expr.ExprNode), g.EngineType) { return nil } iter := newExprIterFromGroupExpr(expr, p) diff --git a/planner/memo/expr_iterator_test.go b/planner/memo/expr_iterator_test.go index 6cb91ab0e1bb2..fceaa75f4580a 100644 --- a/planner/memo/expr_iterator_test.go +++ b/planner/memo/expr_iterator_test.go @@ -34,7 +34,7 @@ func (s *testMemoSuite) TestNewExprIterFromGroupElem(c *C) { expr.Children = append(expr.Children, g1) g2 := NewGroupWithSchema(expr, nil) - pattern := BuildPattern(OperandJoin, BuildPattern(OperandProjection), BuildPattern(OperandSelection)) + pattern := BuildPattern(OperandJoin, EngineAll, BuildPattern(OperandProjection, EngineAll), BuildPattern(OperandSelection, EngineAll)) iter := NewExprIterFromGroupElem(g2.Equivalents.Front(), pattern) c.Assert(iter, NotNil) @@ -75,7 +75,7 @@ func (s *testMemoSuite) TestExprIterNext(c *C) { expr.Children = append(expr.Children, g1) g2 := NewGroupWithSchema(expr, nil) - pattern := BuildPattern(OperandJoin, BuildPattern(OperandProjection), BuildPattern(OperandSelection)) + pattern := BuildPattern(OperandJoin, EngineAll, BuildPattern(OperandProjection, EngineAll), BuildPattern(OperandSelection, EngineAll)) iter := NewExprIterFromGroupElem(g2.Equivalents.Front(), pattern) c.Assert(iter, NotNil) @@ -135,9 +135,9 @@ func (s *testMemoSuite) TestExprIterReset(c *C) { sel3.Children = append(sel3.Children, g2) // create a pattern: join(proj, sel(limit)) - lhsPattern := BuildPattern(OperandProjection) - rhsPattern := BuildPattern(OperandSelection, BuildPattern(OperandLimit)) - pattern := BuildPattern(OperandJoin, lhsPattern, rhsPattern) + lhsPattern := BuildPattern(OperandProjection, EngineAll) + rhsPattern := BuildPattern(OperandSelection, EngineAll, BuildPattern(OperandLimit, EngineAll)) + pattern := BuildPattern(OperandJoin, EngineAll, lhsPattern, rhsPattern) // create expression iterator for the pattern on join iter := NewExprIterFromGroupElem(g3.Equivalents.Front(), pattern) @@ -169,3 +169,102 @@ func (s *testMemoSuite) TestExprIterReset(c *C) { c.Assert(count, Equals, 18) } + +func countMatchedIter(group *Group, pattern *Pattern) int { + count := 0 + for elem := group.Equivalents.Front(); elem != nil; elem = elem.Next() { + iter := NewExprIterFromGroupElem(elem, pattern) + if iter == nil { + continue + } + for ; iter.Matched(); iter.Next() { + count++ + } + } + return count +} + +func (s *testMemoSuite) TestExprIterWithEngineType(c *C) { + g1 := NewGroupWithSchema(NewGroupExpr(plannercore.LogicalSelection{}.Init(s.sctx, 0)), nil).SetEngineType(EngineTiFlash) + g1.Insert(NewGroupExpr(plannercore.LogicalLimit{}.Init(s.sctx, 0))) + g1.Insert(NewGroupExpr(plannercore.LogicalProjection{}.Init(s.sctx, 0))) + g1.Insert(NewGroupExpr(plannercore.LogicalLimit{}.Init(s.sctx, 0))) + + g2 := NewGroupWithSchema(NewGroupExpr(plannercore.LogicalSelection{}.Init(s.sctx, 0)), nil).SetEngineType(EngineTiKV) + g2.Insert(NewGroupExpr(plannercore.LogicalLimit{}.Init(s.sctx, 0))) + g2.Insert(NewGroupExpr(plannercore.LogicalProjection{}.Init(s.sctx, 0))) + g2.Insert(NewGroupExpr(plannercore.LogicalLimit{}.Init(s.sctx, 0))) + + flashGather := NewGroupExpr(plannercore.TableGather{}.Init(s.sctx, 0)) + flashGather.Children = append(flashGather.Children, g1) + g3 := NewGroupWithSchema(flashGather, nil).SetEngineType(EngineTiDB) + + tikvGather := NewGroupExpr(plannercore.TableGather{}.Init(s.sctx, 0)) + tikvGather.Children = append(tikvGather.Children, g2) + g3.Insert(tikvGather) + + join := NewGroupExpr(plannercore.LogicalJoin{}.Init(s.sctx, 0)) + join.Children = append(join.Children, g3, g3) + g4 := NewGroupWithSchema(join, nil).SetEngineType(EngineTiDB) + + // The Groups look like this: + // Group 4 + // Join input:[Group3, Group3] + // Group 3 + // TableGather input:[Group2] EngineTiKV + // TableGather input:[Group1] EngineTiFlash + // Group 2 + // Selection + // Projection + // Limit + // Limit + // Group 1 + // Selection + // Projection + // Limit + // Limit + + p0 := BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOnly)) + c.Assert(countMatchedIter(g3, p0), Equals, 2) + p1 := BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiFlashOnly)) + c.Assert(countMatchedIter(g3, p1), Equals, 2) + p2 := BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOrTiFlash)) + c.Assert(countMatchedIter(g3, p2), Equals, 4) + p3 := BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandSelection, EngineTiFlashOnly)) + c.Assert(countMatchedIter(g3, p3), Equals, 1) + p4 := BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandProjection, EngineTiKVOnly)) + c.Assert(countMatchedIter(g3, p4), Equals, 1) + + p5 := BuildPattern( + OperandJoin, + EngineTiDBOnly, + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOnly)), + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOnly)), + ) + c.Assert(countMatchedIter(g4, p5), Equals, 4) + p6 := BuildPattern( + OperandJoin, + EngineTiDBOnly, + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiFlashOnly)), + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOnly)), + ) + c.Assert(countMatchedIter(g4, p6), Equals, 4) + p7 := BuildPattern( + OperandJoin, + EngineTiDBOnly, + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOrTiFlash)), + BuildPattern(OperandTableGather, EngineTiDBOnly, BuildPattern(OperandLimit, EngineTiKVOrTiFlash)), + ) + c.Assert(countMatchedIter(g4, p7), Equals, 16) + + // This is not a test case for EngineType. This case is to test + // the Pattern without a leaf AnyOperand. It is more efficient to + // test it here. + p8 := BuildPattern( + OperandJoin, + EngineTiDBOnly, + BuildPattern(OperandTableGather, EngineTiDBOnly), + BuildPattern(OperandTableGather, EngineTiDBOnly), + ) + c.Assert(countMatchedIter(g4, p8), Equals, 4) +} diff --git a/planner/memo/group.go b/planner/memo/group.go index 301f1d1a5dce4..5a810f36efc9e 100644 --- a/planner/memo/group.go +++ b/planner/memo/group.go @@ -21,6 +21,55 @@ import ( "github.com/pingcap/tidb/planner/property" ) +// EngineType is determined by whether it's above or below `Gather`s. +// Plan will choose the different engine to be implemented/executed on according to its EngineType. +// Different engine may support different operators with different cost, so we should design +// different transformation and implementation rules for each engine. +type EngineType uint + +const ( + // EngineTiDB stands for groups which is above `Gather`s and will be executed in TiDB layer. + EngineTiDB EngineType = 1 << iota + // EngineTiKV stands for groups which is below `Gather`s and will be executed in TiKV layer. + EngineTiKV + // EngineTiFlash stands for groups which is below `Gather`s and will be executed in TiFlash layer. + EngineTiFlash +) + +// EngineTypeSet is the bit set of EngineTypes. +type EngineTypeSet uint + +const ( + // EngineTiDBOnly is the EngineTypeSet for EngineTiDB only. + EngineTiDBOnly = EngineTypeSet(EngineTiDB) + // EngineTiKVOnly is the EngineTypeSet for EngineTiKV only. + EngineTiKVOnly = EngineTypeSet(EngineTiKV) + // EngineTiFlashOnly is the EngineTypeSet for EngineTiFlash only. + EngineTiFlashOnly = EngineTypeSet(EngineTiFlash) + // EngineTiKVOrTiFlash is the EngineTypeSet for (EngineTiKV | EngineTiFlash). + EngineTiKVOrTiFlash = EngineTypeSet(EngineTiKV | EngineTiFlash) + // EngineAll is the EngineTypeSet for all of the EngineTypes. + EngineAll = EngineTypeSet(EngineTiDB | EngineTiKV | EngineTiFlash) +) + +// Contains checks whether the EngineTypeSet contains the EngineType. +func (e EngineTypeSet) Contains(tp EngineType) bool { + return uint(e)&uint(tp) != 0 +} + +// String implements fmt.Stringer interface. +func (e EngineType) String() string { + switch e { + case EngineTiDB: + return "EngineTiDB" + case EngineTiKV: + return "EngineTiKV" + case EngineTiFlash: + return "EngineTiFlash" + } + return "UnknownEngineType" +} + // Group is short for expression Group, which is used to store all the // logically equivalent expressions. It's a set of GroupExpr. type Group struct { @@ -34,6 +83,8 @@ type Group struct { ImplMap map[string]Implementation Prop *property.LogicalProperty + + EngineType EngineType } // NewGroupWithSchema creates a new Group with given schema. @@ -45,11 +96,18 @@ func NewGroupWithSchema(e *GroupExpr, s *expression.Schema) *Group { FirstExpr: make(map[Operand]*list.Element), ImplMap: make(map[string]Implementation), Prop: prop, + EngineType: EngineTiDB, } g.Insert(e) return g } +// SetEngineType sets the engine type of the group. +func (g *Group) SetEngineType(e EngineType) *Group { + g.EngineType = e + return g +} + // FingerPrint returns the unique fingerprint of the Group. func (g *Group) FingerPrint() string { if g.SelfFingerprint == "" { diff --git a/planner/memo/group_test.go b/planner/memo/group_test.go index bc15913606bbb..acc1e30a0529e 100644 --- a/planner/memo/group_test.go +++ b/planner/memo/group_test.go @@ -117,10 +117,10 @@ type fakeImpl struct { plan plannercore.PhysicalPlan } -func (impl *fakeImpl) CalcCost(float64, []float64, ...*Group) float64 { return 0 } -func (impl *fakeImpl) SetCost(float64) {} -func (impl *fakeImpl) GetCost() float64 { return 0 } -func (impl *fakeImpl) GetPlan() plannercore.PhysicalPlan { return impl.plan } +func (impl *fakeImpl) CalcCost(float64, ...Implementation) float64 { return 0 } +func (impl *fakeImpl) SetCost(float64) {} +func (impl *fakeImpl) GetCost() float64 { return 0 } +func (impl *fakeImpl) GetPlan() plannercore.PhysicalPlan { return impl.plan } func (s *testMemoSuite) TestGetInsertGroupImpl(c *C) { g := NewGroupWithSchema(NewGroupExpr(plannercore.LogicalLimit{}.Init(s.sctx, 0)), nil) @@ -139,3 +139,25 @@ func (s *testMemoSuite) TestGetInsertGroupImpl(c *C) { newImpl = g.GetImpl(orderProp) c.Assert(newImpl, IsNil) } + +func (s *testMemoSuite) TestEngineTypeSet(c *C) { + c.Assert(EngineAll.Contains(EngineTiDB), IsTrue) + c.Assert(EngineAll.Contains(EngineTiKV), IsTrue) + c.Assert(EngineAll.Contains(EngineTiFlash), IsTrue) + + c.Assert(EngineTiDBOnly.Contains(EngineTiDB), IsTrue) + c.Assert(EngineTiDBOnly.Contains(EngineTiKV), IsFalse) + c.Assert(EngineTiDBOnly.Contains(EngineTiFlash), IsFalse) + + c.Assert(EngineTiKVOnly.Contains(EngineTiDB), IsFalse) + c.Assert(EngineTiKVOnly.Contains(EngineTiKV), IsTrue) + c.Assert(EngineTiKVOnly.Contains(EngineTiFlash), IsFalse) + + c.Assert(EngineTiFlashOnly.Contains(EngineTiDB), IsFalse) + c.Assert(EngineTiFlashOnly.Contains(EngineTiKV), IsFalse) + c.Assert(EngineTiFlashOnly.Contains(EngineTiFlash), IsTrue) + + c.Assert(EngineTiKVOrTiFlash.Contains(EngineTiDB), IsFalse) + c.Assert(EngineTiKVOrTiFlash.Contains(EngineTiKV), IsTrue) + c.Assert(EngineTiKVOrTiFlash.Contains(EngineTiFlash), IsTrue) +} diff --git a/planner/memo/implementation.go b/planner/memo/implementation.go index dc8d75aaa7d75..3eb626b3b0d98 100644 --- a/planner/memo/implementation.go +++ b/planner/memo/implementation.go @@ -19,7 +19,7 @@ import ( // Implementation defines the interface for cost of physical plan. type Implementation interface { - CalcCost(outCount float64, childCosts []float64, children ...*Group) float64 + CalcCost(outCount float64, children ...Implementation) float64 SetCost(cost float64) GetCost() float64 GetPlan() plannercore.PhysicalPlan diff --git a/planner/memo/pattern.go b/planner/memo/pattern.go index f9f68294464de..fff3529bd4fba 100644 --- a/planner/memo/pattern.go +++ b/planner/memo/pattern.go @@ -118,17 +118,30 @@ func (o Operand) Match(t Operand) bool { return false } -// Pattern defines the Match pattern for a rule. -// It describes a piece of logical expression. -// It's a tree-like structure and each node in the tree is an Operand. +// Pattern defines the match pattern for a rule. It's a tree-like structure +// which is a piece of a logical expression. Each node in the Pattern tree is +// defined by an Operand and EngineType pair. type Pattern struct { Operand + EngineTypeSet Children []*Pattern } -// NewPattern creats a pattern node according to the Operand. -func NewPattern(operand Operand) *Pattern { - return &Pattern{Operand: operand} +// Match checks whether the EngineTypeSet contains the given EngineType +// and whether the two Operands match. +func (p *Pattern) Match(o Operand, e EngineType) bool { + return p.EngineTypeSet.Contains(e) && p.Operand.Match(o) +} + +// MatchOperandAny checks whether the pattern's Operand is OperandAny +// and the EngineTypeSet contains the given EngineType. +func (p *Pattern) MatchOperandAny(e EngineType) bool { + return p.EngineTypeSet.Contains(e) && p.Operand == OperandAny +} + +// NewPattern creates a pattern node according to the Operand and EngineType. +func NewPattern(operand Operand, engineTypeSet EngineTypeSet) *Pattern { + return &Pattern{Operand: operand, EngineTypeSet: engineTypeSet} } // SetChildren sets the Children information for a pattern node. @@ -136,10 +149,10 @@ func (p *Pattern) SetChildren(children ...*Pattern) { p.Children = children } -// BuildPattern builds a Pattern from Operand and child Patterns. +// BuildPattern builds a Pattern from Operand, EngineType and child Patterns. // Used in GetPattern() of Transformation interface to generate a Pattern. -func BuildPattern(operand Operand, children ...*Pattern) *Pattern { - p := &Pattern{Operand: operand} +func BuildPattern(operand Operand, engineTypeSet EngineTypeSet, children ...*Pattern) *Pattern { + p := &Pattern{Operand: operand, EngineTypeSet: engineTypeSet} p.Children = children return p } diff --git a/planner/memo/pattern_test.go b/planner/memo/pattern_test.go index 526f82cd90deb..fb44a899fc0d9 100644 --- a/planner/memo/pattern_test.go +++ b/planner/memo/pattern_test.go @@ -60,24 +60,24 @@ func (s *testMemoSuite) TestOperandMatch(c *C) { } func (s *testMemoSuite) TestNewPattern(c *C) { - p := NewPattern(OperandAny) + p := NewPattern(OperandAny, EngineAll) c.Assert(p.Operand, Equals, OperandAny) c.Assert(p.Children, IsNil) - p = NewPattern(OperandJoin) + p = NewPattern(OperandJoin, EngineAll) c.Assert(p.Operand, Equals, OperandJoin) c.Assert(p.Children, IsNil) } func (s *testMemoSuite) TestPatternSetChildren(c *C) { - p := NewPattern(OperandAny) - p.SetChildren(NewPattern(OperandLimit)) + p := NewPattern(OperandAny, EngineAll) + p.SetChildren(NewPattern(OperandLimit, EngineAll)) c.Assert(len(p.Children), Equals, 1) c.Assert(p.Children[0].Operand, Equals, OperandLimit) c.Assert(p.Children[0].Children, IsNil) - p = NewPattern(OperandJoin) - p.SetChildren(NewPattern(OperandProjection), NewPattern(OperandSelection)) + p = NewPattern(OperandJoin, EngineAll) + p.SetChildren(NewPattern(OperandProjection, EngineAll), NewPattern(OperandSelection, EngineAll)) c.Assert(len(p.Children), Equals, 2) c.Assert(p.Children[0].Operand, Equals, OperandProjection) c.Assert(p.Children[0].Children, IsNil)