From 76aae0d5c594f538af62caa883c73188a44170c4 Mon Sep 17 00:00:00 2001 From: Yifan Xu <30385241+xuyifangreeneyes@users.noreply.github.com> Date: Sat, 25 Dec 2021 18:21:48 +0800 Subject: [PATCH] planner: change predicateColumnCollector to columnStatsUsageCollector and collect histogram-needed columns (#30671) --- expression/util.go | 8 +- planner/core/collect_column_stats_usage.go | 296 ++++++++------- .../core/collect_column_stats_usage_test.go | 353 ++++++++++++------ planner/core/logical_plan_test.go | 18 +- planner/core/logical_plans.go | 2 +- 5 files changed, 425 insertions(+), 252 deletions(-) diff --git a/expression/util.go b/expression/util.go index 3a793cfddc640..a6d9ef0d6169d 100644 --- a/expression/util.go +++ b/expression/util.go @@ -166,8 +166,8 @@ func extractColumns(result []*Column, expr Expression, filter func(*Column) bool return result } -// ExtractColumnsAndCorColumns extracts columns and correlated columns from `expr` and append them to `result`. -func ExtractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { +// extractColumnsAndCorColumns extracts columns and correlated columns from `expr` and append them to `result`. +func extractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { switch v := expr.(type) { case *Column: result = append(result, v) @@ -175,7 +175,7 @@ func ExtractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { result = append(result, &v.Column) case *ScalarFunction: for _, arg := range v.GetArgs() { - result = ExtractColumnsAndCorColumns(result, arg) + result = extractColumnsAndCorColumns(result, arg) } } return result @@ -184,7 +184,7 @@ func ExtractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { // ExtractColumnsAndCorColumnsFromExpressions extracts columns and correlated columns from expressions and append them to `result`. func ExtractColumnsAndCorColumnsFromExpressions(result []*Column, list []Expression) []*Column { for _, expr := range list { - result = ExtractColumnsAndCorColumns(result, expr) + result = extractColumnsAndCorColumns(result, expr) } return result } diff --git a/planner/core/collect_column_stats_usage.go b/planner/core/collect_column_stats_usage.go index 6396b1ddad34f..0305223074984 100644 --- a/planner/core/collect_column_stats_usage.go +++ b/planner/core/collect_column_stats_usage.go @@ -19,27 +19,49 @@ import ( "github.com/pingcap/tidb/parser/model" ) -// predicateColumnCollector collects predicate columns from logical plan. Predicate columns are the columns whose statistics -// are utilized when making query plans, which usually occur in where conditions, join conditions and so on. -type predicateColumnCollector struct { - // colMap maps expression.Column.UniqueID to the table columns whose statistics are utilized to calculate statistics of the column. - colMap map[int64]map[model.TableColumnID]struct{} +const ( + collectPredicateColumns uint64 = 1 << iota + collectHistNeededColumns +) + +// columnStatsUsageCollector collects predicate columns and/or histogram-needed columns from logical plan. +// Predicate columns are the columns whose statistics are utilized when making query plans, which usually occur in where conditions, join conditions and so on. +// Histogram-needed columns are the columns whose histograms are utilized when making query plans, which usually occur in the conditions pushed down to DataSource. +// The set of histogram-needed columns is the subset of that of predicate columns. +type columnStatsUsageCollector struct { + // collectMode indicates whether to collect predicate columns and/or histogram-needed columns + collectMode uint64 // predicateCols records predicate columns. predicateCols map[model.TableColumnID]struct{} + // colMap maps expression.Column.UniqueID to the table columns whose statistics may be utilized to calculate statistics of the column. + // It is used for collecting predicate columns. + // For example, in `select count(distinct a, b) as e from t`, the count of column `e` is calculated as `max(ndv(t.a), ndv(t.b))` if + // we don't know `ndv(t.a, t.b)`(see (*LogicalAggregation).DeriveStats and getColsNDV for details). So when calculating the statistics + // of column `e`, we may use the statistics of column `t.a` and `t.b`. + colMap map[int64]map[model.TableColumnID]struct{} + // histNeededCols records histogram-needed columns + histNeededCols map[model.TableColumnID]struct{} // cols is used to store columns collected from expressions and saves some allocation. cols []*expression.Column } -func newPredicateColumnCollector() *predicateColumnCollector { - return &predicateColumnCollector{ - colMap: make(map[int64]map[model.TableColumnID]struct{}), - predicateCols: make(map[model.TableColumnID]struct{}), +func newColumnStatsUsageCollector(collectMode uint64) *columnStatsUsageCollector { + collector := &columnStatsUsageCollector{ + collectMode: collectMode, // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. cols: make([]*expression.Column, 0, 8), } + if collectMode&collectPredicateColumns != 0 { + collector.predicateCols = make(map[model.TableColumnID]struct{}) + collector.colMap = make(map[int64]map[model.TableColumnID]struct{}) + } + if collectMode&collectHistNeededColumns != 0 { + collector.histNeededCols = make(map[model.TableColumnID]struct{}) + } + return collector } -func (c *predicateColumnCollector) addPredicateColumn(col *expression.Column) { +func (c *columnStatsUsageCollector) addPredicateColumn(col *expression.Column) { tblColIDs, ok := c.colMap[col.UniqueID] if !ok { // It may happen if some leaf of logical plan is LogicalMemTable/LogicalShow/LogicalShowDDLJobs. @@ -50,21 +72,14 @@ func (c *predicateColumnCollector) addPredicateColumn(col *expression.Column) { } } -func (c *predicateColumnCollector) addPredicateColumnsFromExpression(expr expression.Expression) { - cols := expression.ExtractColumnsAndCorColumns(c.cols[:0], expr) - for _, col := range cols { - c.addPredicateColumn(col) - } -} - -func (c *predicateColumnCollector) addPredicateColumnsFromExpressions(list []expression.Expression) { +func (c *columnStatsUsageCollector) addPredicateColumnsFromExpressions(list []expression.Expression) { cols := expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list) for _, col := range cols { c.addPredicateColumn(col) } } -func (c *predicateColumnCollector) updateColMap(col *expression.Column, relatedCols []*expression.Column) { +func (c *columnStatsUsageCollector) updateColMap(col *expression.Column, relatedCols []*expression.Column) { if _, ok := c.colMap[col.UniqueID]; !ok { c.colMap[col.UniqueID] = map[model.TableColumnID]struct{}{} } @@ -80,15 +95,11 @@ func (c *predicateColumnCollector) updateColMap(col *expression.Column, relatedC } } -func (c *predicateColumnCollector) updateColMapFromExpression(col *expression.Column, expr expression.Expression) { - c.updateColMap(col, expression.ExtractColumnsAndCorColumns(c.cols[:0], expr)) -} - -func (c *predicateColumnCollector) updateColMapFromExpressions(col *expression.Column, list []expression.Expression) { +func (c *columnStatsUsageCollector) updateColMapFromExpressions(col *expression.Column, list []expression.Expression) { c.updateColMap(col, expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list)) } -func (ds *DataSource) updateColMapAndAddPredicateColumns(c *predicateColumnCollector) { +func (c *columnStatsUsageCollector) collectPredicateColumnsForDataSource(ds *DataSource) { tblID := ds.TableInfo().ID for _, col := range ds.Schema().Columns { tblColID := model.TableColumnID{TableID: tblID, ColumnID: col.ID} @@ -98,7 +109,7 @@ func (ds *DataSource) updateColMapAndAddPredicateColumns(c *predicateColumnColle c.addPredicateColumnsFromExpressions(ds.pushedDownConds) } -func (p *LogicalJoin) updateColMapAndAddPredicateColumns(c *predicateColumnCollector) { +func (c *columnStatsUsageCollector) collectPredicateColumnsForJoin(p *LogicalJoin) { // The only schema change is merging two schemas so there is no new column. // Assume statistics of all the columns in EqualConditions/LeftConditions/RightConditions/OtherConditions are needed. exprs := make([]expression.Expression, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) @@ -117,7 +128,7 @@ func (p *LogicalJoin) updateColMapAndAddPredicateColumns(c *predicateColumnColle c.addPredicateColumnsFromExpressions(exprs) } -func (p *LogicalUnionAll) updateColMapAndAddPredicateColumns(c *predicateColumnCollector) { +func (c *columnStatsUsageCollector) collectPredicateColumnsForUnionAll(p *LogicalUnionAll) { // statistics of the ith column of UnionAll come from statistics of the ith column of each child. schemas := make([]*expression.Schema, 0, len(p.Children())) relatedCols := make([]*expression.Column, 0, len(p.Children())) @@ -133,120 +144,143 @@ func (p *LogicalUnionAll) updateColMapAndAddPredicateColumns(c *predicateColumnC } } -func (c *predicateColumnCollector) collectFromPlan(lp LogicalPlan) { +func (c *columnStatsUsageCollector) addHistNeededColumns(ds *DataSource) { + columns := expression.ExtractColumnsFromExpressions(c.cols[:0], ds.pushedDownConds, nil) + for _, col := range columns { + tblColID := model.TableColumnID{TableID: ds.physicalTableID, ColumnID: col.ID} + c.histNeededCols[tblColID] = struct{}{} + } +} + +func (c *columnStatsUsageCollector) collectFromPlan(lp LogicalPlan) { for _, child := range lp.Children() { c.collectFromPlan(child) } - switch x := lp.(type) { - case *DataSource: - x.updateColMapAndAddPredicateColumns(c) - case *LogicalIndexScan: - x.Source.updateColMapAndAddPredicateColumns(c) - // TODO: Is it redundant to add predicate columns from LogicalIndexScan.AccessConds? Is LogicalIndexScan.AccessConds a subset of LogicalIndexScan.Source.pushedDownConds. - c.addPredicateColumnsFromExpressions(x.AccessConds) - case *LogicalTableScan: - x.Source.updateColMapAndAddPredicateColumns(c) - // TODO: Is it redundant to add predicate columns from LogicalTableScan.AccessConds? Is LogicalTableScan.AccessConds a subset of LogicalTableScan.Source.pushedDownConds. - c.addPredicateColumnsFromExpressions(x.AccessConds) - case *TiKVSingleGather: - // TODO: Is it redundant? - x.Source.updateColMapAndAddPredicateColumns(c) - case *LogicalProjection: - // Schema change from children to self. - schema := x.Schema() - for i, expr := range x.Exprs { - c.updateColMapFromExpression(schema.Columns[i], expr) - } - case *LogicalSelection: - // Though the conditions in LogicalSelection are complex conditions which cannot be pushed down to DataSource, we still - // regard statistics of the columns in the conditions as needed. - c.addPredicateColumnsFromExpressions(x.Conditions) - case *LogicalAggregation: - // Just assume statistics of all the columns in GroupByItems are needed. - c.addPredicateColumnsFromExpressions(x.GroupByItems) - // Schema change from children to self. - schema := x.Schema() - for i, aggFunc := range x.AggFuncs { - c.updateColMapFromExpressions(schema.Columns[i], aggFunc.Args) - } - case *LogicalWindow: - // Statistics of the columns in LogicalWindow.PartitionBy are used in optimizeByShuffle4Window. - // It seems that we don't use statistics of the columns in LogicalWindow.OrderBy currently? - for _, item := range x.PartitionBy { - c.addPredicateColumn(item.Col) - } - // Schema change from children to self. - windowColumns := x.GetWindowResultColumns() - for i, col := range windowColumns { - c.updateColMapFromExpressions(col, x.WindowFuncDescs[i].Args) - } - case *LogicalJoin: - x.updateColMapAndAddPredicateColumns(c) - case *LogicalApply: - x.updateColMapAndAddPredicateColumns(c) - // Assume statistics of correlated columns are needed. - // Correlated columns can be found in LogicalApply.Children()[0].Schema(). Since we already visit LogicalApply.Children()[0], - // correlated columns must have existed in predicateColumnCollector.colMap. - for _, corCols := range x.CorCols { - c.addPredicateColumn(&corCols.Column) - } - case *LogicalSort: - // Assume statistics of all the columns in ByItems are needed. - for _, item := range x.ByItems { - c.addPredicateColumnsFromExpression(item.Expr) - } - case *LogicalTopN: - // Assume statistics of all the columns in ByItems are needed. - for _, item := range x.ByItems { - c.addPredicateColumnsFromExpression(item.Expr) - } - case *LogicalUnionAll: - x.updateColMapAndAddPredicateColumns(c) - case *LogicalPartitionUnionAll: - x.updateColMapAndAddPredicateColumns(c) - case *LogicalCTE: - // Visit seedPartLogicalPlan and recursivePartLogicalPlan first. - c.collectFromPlan(x.cte.seedPartLogicalPlan) - if x.cte.recursivePartLogicalPlan != nil { - c.collectFromPlan(x.cte.recursivePartLogicalPlan) - } - // Schema change from seedPlan/recursivePlan to self. - columns := x.Schema().Columns - seedColumns := x.cte.seedPartLogicalPlan.Schema().Columns - var recursiveColumns []*expression.Column - if x.cte.recursivePartLogicalPlan != nil { - recursiveColumns = x.cte.recursivePartLogicalPlan.Schema().Columns - } - relatedCols := make([]*expression.Column, 0, 2) - for i, col := range columns { - relatedCols = append(relatedCols[:0], seedColumns[i]) - if recursiveColumns != nil { - relatedCols = append(relatedCols, recursiveColumns[i]) + if c.collectMode&collectPredicateColumns != 0 { + switch x := lp.(type) { + case *DataSource: + c.collectPredicateColumnsForDataSource(x) + case *LogicalIndexScan: + c.collectPredicateColumnsForDataSource(x.Source) + c.addPredicateColumnsFromExpressions(x.AccessConds) + case *LogicalTableScan: + c.collectPredicateColumnsForDataSource(x.Source) + c.addPredicateColumnsFromExpressions(x.AccessConds) + case *LogicalProjection: + // Schema change from children to self. + schema := x.Schema() + for i, expr := range x.Exprs { + c.updateColMapFromExpressions(schema.Columns[i], []expression.Expression{expr}) } - c.updateColMap(col, relatedCols) - } - // If IsDistinct is true, then we use getColsNDV to calculate row count(see (*LogicalCTE).DeriveStat). In this case - // statistics of all the columns are needed. - if x.cte.IsDistinct { - for _, col := range columns { - c.addPredicateColumn(col) + case *LogicalSelection: + // Though the conditions in LogicalSelection are complex conditions which cannot be pushed down to DataSource, we still + // regard statistics of the columns in the conditions as needed. + c.addPredicateColumnsFromExpressions(x.Conditions) + case *LogicalAggregation: + // Just assume statistics of all the columns in GroupByItems are needed. + c.addPredicateColumnsFromExpressions(x.GroupByItems) + // Schema change from children to self. + schema := x.Schema() + for i, aggFunc := range x.AggFuncs { + c.updateColMapFromExpressions(schema.Columns[i], aggFunc.Args) + } + case *LogicalWindow: + // Statistics of the columns in LogicalWindow.PartitionBy are used in optimizeByShuffle4Window. + // We don't use statistics of the columns in LogicalWindow.OrderBy currently. + for _, item := range x.PartitionBy { + c.addPredicateColumn(item.Col) + } + // Schema change from children to self. + windowColumns := x.GetWindowResultColumns() + for i, col := range windowColumns { + c.updateColMapFromExpressions(col, x.WindowFuncDescs[i].Args) + } + case *LogicalJoin: + c.collectPredicateColumnsForJoin(x) + case *LogicalApply: + c.collectPredicateColumnsForJoin(&x.LogicalJoin) + // Assume statistics of correlated columns are needed. + // Correlated columns can be found in LogicalApply.Children()[0].Schema(). Since we already visit LogicalApply.Children()[0], + // correlated columns must have existed in columnStatsUsageCollector.colMap. + for _, corCols := range x.CorCols { + c.addPredicateColumn(&corCols.Column) + } + case *LogicalSort: + // Assume statistics of all the columns in ByItems are needed. + for _, item := range x.ByItems { + c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) + } + case *LogicalTopN: + // Assume statistics of all the columns in ByItems are needed. + for _, item := range x.ByItems { + c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) + } + case *LogicalUnionAll: + c.collectPredicateColumnsForUnionAll(x) + case *LogicalPartitionUnionAll: + c.collectPredicateColumnsForUnionAll(&x.LogicalUnionAll) + case *LogicalCTE: + // Visit seedPartLogicalPlan and recursivePartLogicalPlan first. + c.collectFromPlan(x.cte.seedPartLogicalPlan) + if x.cte.recursivePartLogicalPlan != nil { + c.collectFromPlan(x.cte.recursivePartLogicalPlan) + } + // Schema change from seedPlan/recursivePlan to self. + columns := x.Schema().Columns + seedColumns := x.cte.seedPartLogicalPlan.Schema().Columns + var recursiveColumns []*expression.Column + if x.cte.recursivePartLogicalPlan != nil { + recursiveColumns = x.cte.recursivePartLogicalPlan.Schema().Columns + } + relatedCols := make([]*expression.Column, 0, 2) + for i, col := range columns { + relatedCols = append(relatedCols[:0], seedColumns[i]) + if recursiveColumns != nil { + relatedCols = append(relatedCols, recursiveColumns[i]) + } + c.updateColMap(col, relatedCols) + } + // If IsDistinct is true, then we use getColsNDV to calculate row count(see (*LogicalCTE).DeriveStat). In this case + // statistics of all the columns are needed. + if x.cte.IsDistinct { + for _, col := range columns { + c.addPredicateColumn(col) + } + } + case *LogicalCTETable: + // Schema change from seedPlan to self. + for i, col := range x.Schema().Columns { + c.updateColMap(col, []*expression.Column{x.seedSchema.Columns[i]}) } } - case *LogicalCTETable: - // Schema change from seedPlan to self. - for i, col := range x.Schema().Columns { - c.updateColMap(col, []*expression.Column{x.seedSchema.Columns[i]}) + } + if c.collectMode&collectHistNeededColumns != 0 { + // Histogram-needed columns are the columns which occur in the conditions pushed down to DataSource. + // We don't consider LogicalCTE because seedLogicalPlan and recursiveLogicalPlan haven't got logical optimization + // yet(seedLogicalPlan and recursiveLogicalPlan are optimized in DeriveStats phase). Without logical optimization, + // there is no condition pushed down to DataSource so no histogram-needed column can be collected. + switch x := lp.(type) { + case *DataSource: + c.addHistNeededColumns(x) + case *LogicalIndexScan: + c.addHistNeededColumns(x.Source) + case *LogicalTableScan: + c.addHistNeededColumns(x.Source) } } } -// CollectPredicateColumnsForTest collects predicate columns from logical plan. It is only for test. -func CollectPredicateColumnsForTest(lp LogicalPlan) []model.TableColumnID { - collector := newPredicateColumnCollector() +// CollectColumnStatsUsage collects column stats usage from logical plan. +// The first return value is predicate columns and the second return value is histogram-needed columns. +func CollectColumnStatsUsage(lp LogicalPlan) ([]model.TableColumnID, []model.TableColumnID) { + collector := newColumnStatsUsageCollector(collectPredicateColumns | collectHistNeededColumns) collector.collectFromPlan(lp) - tblColIDs := make([]model.TableColumnID, 0, len(collector.predicateCols)) - for tblColID := range collector.predicateCols { - tblColIDs = append(tblColIDs, tblColID) + set2slice := func(set map[model.TableColumnID]struct{}) []model.TableColumnID { + ret := make([]model.TableColumnID, 0, len(set)) + for tblColID := range set { + ret = append(ret, tblColID) + } + return ret } - return tblColIDs + return set2slice(collector.predicateCols), set2slice(collector.histNeededCols) } diff --git a/planner/core/collect_column_stats_usage_test.go b/planner/core/collect_column_stats_usage_test.go index b270b6f7c1bfc..5cc64a80e831e 100644 --- a/planner/core/collect_column_stats_usage_test.go +++ b/planner/core/collect_column_stats_usage_test.go @@ -12,210 +12,335 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core_test +package core import ( "context" - "fmt" - "testing" + "sort" + . "github.com/pingcap/check" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser/model" - plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/hint" - "github.com/pingcap/tidb/util/logutil" - "github.com/stretchr/testify/require" + "github.com/pingcap/tidb/util/testleak" ) -func TestCollectPredicateColumns(t *testing.T) { - store, dom, clean := testkit.CreateMockStoreAndDomain(t) - defer clean() - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1, t2") - tk.MustExec("set @@session.tidb_partition_prune_mode = 'static'") - tk.MustExec("create table t1(a int, b int, c int)") - tk.MustExec("create table t2(a int, b int, c int)") - tk.MustExec("create table t3(a int, b int, c int) partition by range(a) (partition p0 values less than (10), partition p1 values less than (20), partition p2 values less than maxvalue)") +func getColumnName(c *C, is infoschema.InfoSchema, tblColID model.TableColumnID, comment CommentInterface) (string, bool) { + var tblInfo *model.TableInfo + var prefix string + if tbl, ok := is.TableByID(tblColID.TableID); ok { + tblInfo = tbl.Meta() + prefix = tblInfo.Name.L + "." + } else { + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue, comment) + for _, tbl := range db.Tables { + pi := tbl.GetPartitionInfo() + if pi == nil { + continue + } + for _, def := range pi.Definitions { + if def.ID == tblColID.TableID { + tblInfo = tbl + prefix = tbl.Name.L + "." + def.Name.L + "." + break + } + } + if tblInfo != nil { + break + } + } + if tblInfo == nil { + return "", false + } + } + for _, col := range tblInfo.Columns { + if tblColID.ColumnID == col.ID { + return prefix + col.Name.L, true + } + } + return "", false +} + +func checkColumnStatsUsage(c *C, is infoschema.InfoSchema, lp LogicalPlan, onlyHistNeeded bool, expected []string, comment CommentInterface) { + var tblColIDs []model.TableColumnID + if onlyHistNeeded { + _, tblColIDs = CollectColumnStatsUsage(lp) + } else { + tblColIDs, _ = CollectColumnStatsUsage(lp) + } + cols := make([]string, 0, len(tblColIDs)) + for _, tblColID := range tblColIDs { + col, ok := getColumnName(c, is, tblColID, comment) + c.Assert(ok, IsTrue, comment) + cols = append(cols, col) + } + sort.Strings(cols) + c.Assert(cols, DeepEquals, expected, comment) +} +func (s *testPlanSuite) TestCollectPredicateColumns(c *C) { + defer testleak.AfterTest(c)() tests := []struct { - sql string - res []string + pruneMode string + sql string + res []string }{ { // DataSource - sql: "select * from t1 where a > 2", - res: []string{"t1.a"}, + sql: "select * from t where a > 2", + res: []string{"t.a"}, }, { // DataSource - sql: "select * from t1 where b in (2, 5) or c = 5", - res: []string{"t1.b", "t1.c"}, + sql: "select * from t where b in (2, 5) or c = 5", + res: []string{"t.b", "t.c"}, }, { // LogicalProjection - sql: "select * from (select a + b as ab, c from t1) as tmp where ab > 4", - res: []string{"t1.a", "t1.b"}, + sql: "select * from (select a + b as ab, c from t) as tmp where ab > 4", + res: []string{"t.a", "t.b"}, }, { // LogicalAggregation - sql: "select b, count(*) from t1 group by b", - res: []string{"t1.b"}, + sql: "select b, count(*) from t group by b", + res: []string{"t.b"}, }, { // LogicalAggregation - sql: "select b, sum(a) from t1 group by b having sum(a) > 3", - res: []string{"t1.a", "t1.b"}, + sql: "select b, sum(a) from t group by b having sum(a) > 3", + res: []string{"t.a", "t.b"}, }, { // LogicalAggregation - sql: "select count(*), sum(a), sum(c) from t1", + sql: "select count(*), sum(a), sum(c) from t", res: []string{}, }, { // LogicalAggregation - sql: "(select a, b from t1) union (select a, c from t2)", - res: []string{"t1.a", "t1.b", "t2.a", "t2.c"}, + sql: "(select a, c from t) union (select a, b from t2)", + res: []string{"t.a", "t.c", "t2.a", "t2.b"}, }, { // LogicalWindow - sql: "select avg(b) over(partition by a) from t1", - res: []string{"t1.a"}, + sql: "select avg(b) over(partition by a) from t", + res: []string{"t.a"}, }, { // LogicalWindow - sql: "select * from (select avg(b) over(partition by a) as w from t1) as tmp where w > 4", - res: []string{"t1.a", "t1.b"}, + sql: "select * from (select avg(b) over(partition by a) as w from t) as tmp where w > 4", + res: []string{"t.a", "t.b"}, }, { // LogicalWindow - sql: "select row_number() over(partition by a order by c) from t1", - res: []string{"t1.a"}, + sql: "select row_number() over(partition by a order by c) from t", + res: []string{"t.a"}, }, { // LogicalJoin - sql: "select * from t1, t2 where t1.a = t2.a", - res: []string{"t1.a", "t2.a"}, + sql: "select * from t, t2 where t.a = t2.a", + res: []string{"t.a", "t2.a"}, }, { // LogicalJoin - sql: "select * from t1 as x join t2 as y on x.b + y.c > 2", - res: []string{"t1.b", "t2.c"}, + sql: "select * from t as x join t2 as y on x.c + y.b > 2", + res: []string{"t.c", "t2.b"}, }, { // LogicalJoin - sql: "select * from t1 as x join t2 as y on x.a = y.a and x.b < 3 and y.c > 2", - res: []string{"t1.a", "t1.b", "t2.a", "t2.c"}, + sql: "select * from t as x join t2 as y on x.a = y.a and x.c < 3 and y.b > 2", + res: []string{"t.a", "t.c", "t2.a", "t2.b"}, }, { // LogicalJoin - sql: "select x.b, y.c, sum(x.c), sum(y.b) from t1 as x join t2 as y on x.a = y.a group by x.b, y.c order by x.b", - res: []string{"t1.a", "t1.b", "t2.a", "t2.c"}, + sql: "select x.c, y.b, sum(x.b), sum(y.a) from t as x join t2 as y on x.a < y.a group by x.c, y.b order by x.c", + res: []string{"t.a", "t.c", "t2.a", "t2.b"}, }, { - // LogicalApply - sql: "select * from t1 where t1.b > all(select b from t2 where t2.c > 2)", - res: []string{"t1.b", "t2.b", "t2.c"}, + // LogicalApply, LogicalJoin + sql: "select * from t2 where t2.b > all(select b from t where t.c > 2)", + res: []string{"t.b", "t.c", "t2.b"}, }, { - // LogicalApply - sql: "select * from t1 where t1.b > (select count(b) from t2 where t2.c > t1.a)", - res: []string{"t1.a", "t1.b", "t2.b", "t2.c"}, + // LogicalApply, LogicalJoin + sql: "select * from t2 where t2.b > any(select b from t where t.c > 2)", + res: []string{"t.b", "t.c", "t2.b"}, + }, + { + // LogicalApply, LogicalJoin + sql: "select * from t2 where t2.b > (select sum(b) from t where t.c > t2.a)", + res: []string{"t.b", "t.c", "t2.a", "t2.b"}, }, { // LogicalApply - sql: "select * from t1 where t1.b > (select count(*) from t2 where t2.c > t1.a)", - res: []string{"t1.a", "t1.b", "t2.c"}, + sql: "select * from t2 where t2.b > (select count(*) from t where t.a > t2.a)", + res: []string{"t.a", "t2.a", "t2.b"}, + }, + { + // LogicalApply, LogicalJoin + sql: "select * from t2 where exists (select * from t where t.a > t2.b)", + res: []string{"t.a", "t2.b"}, + }, + { + // LogicalApply, LogicalJoin + sql: "select * from t2 where not exists (select * from t where t.a > t2.b)", + res: []string{"t.a", "t2.b"}, + }, + { + // LogicalJoin + sql: "select * from t2 where t2.a in (select b from t)", + res: []string{"t.b", "t2.a"}, + }, + { + // LogicalApply, LogicalJoin + sql: "select * from t2 where t2.a not in (select b from t)", + res: []string{"t.b", "t2.a"}, }, { // LogicalSort - sql: "select * from t1 order by c", - res: []string{"t1.c"}, + sql: "select * from t order by c", + res: []string{"t.c"}, }, { // LogicalTopN - sql: "select * from t1 order by a + b limit 10", - res: []string{"t1.a", "t1.b"}, + sql: "select * from t order by a + b limit 10", + res: []string{"t.a", "t.b"}, }, { // LogicalUnionAll - sql: "select * from ((select a, b from t1) union all (select a, c from t2)) as tmp where tmp.b > 2", - res: []string{"t1.b", "t2.c"}, - }, - { - // LogicalPartitionUnionAll - sql: "select * from t3 where a < 15 and b > 1", - res: []string{"t3.a", "t3.b"}, + sql: "select * from ((select a, c from t) union all (select a, b from t2)) as tmp where tmp.c > 2", + res: []string{"t.c", "t2.b"}, }, { // LogicalCTE - sql: "with cte(x, y) as (select a + 1, b from t1 where b > 1) select * from cte where x > 3", - res: []string{"t1.a", "t1.b"}, + sql: "with cte(x, y) as (select a + 1, b from t where b > 1) select * from cte where x > 3", + res: []string{"t.a", "t.b"}, }, { // LogicalCTE, LogicalCTETable - sql: "with recursive cte(x, y) as (select c, 1 from t1 union all select x + 1, y from cte where x < 5) select * from cte", - res: []string{"t1.c"}, + sql: "with recursive cte(x, y) as (select c, 1 from t union all select x + 1, y from cte where x < 5) select * from cte", + res: []string{"t.c"}, }, { // LogicalCTE, LogicalCTETable - sql: "with recursive cte(x, y) as (select 1, c from t1 union all select x + 1, y from cte where x < 5) select * from cte where y > 1", - res: []string{"t1.c"}, + sql: "with recursive cte(x, y) as (select 1, c from t union all select x + 1, y from cte where x < 5) select * from cte where y > 1", + res: []string{"t.c"}, }, { // LogicalCTE, LogicalCTETable - sql: "with recursive cte(x, y) as (select a, b from t1 union select x + 1, y from cte where x < 5) select * from cte", - res: []string{"t1.a", "t1.b"}, + sql: "with recursive cte(x, y) as (select a, b from t union select x + 1, y from cte where x < 5) select * from cte", + res: []string{"t.a", "t.b"}, + }, + { + // LogicalPartitionUnionAll, static partition prune mode, use table ID rather than partition ID + pruneMode: "static", + sql: "select * from pt1 where ptn < 20 and b > 1", + res: []string{"pt1.b", "pt1.ptn"}, + }, + { + // dynamic partition prune mode, use table ID rather than partition ID + pruneMode: "dynamic", + sql: "select * from pt1 where ptn < 20 and b > 1", + res: []string{"pt1.b", "pt1.ptn"}, }, } ctx := context.Background() - sctx := tk.Session() - is := dom.InfoSchema() - getColName := func(tblColID model.TableColumnID) (string, bool) { - tbl, ok := is.TableByID(tblColID.TableID) - if !ok { - return "", false - } - tblInfo := tbl.Meta() - for _, col := range tblInfo.Columns { - if tblColID.ColumnID == col.ID { - return tblInfo.Name.L + "." + col.Name.L, true - } + for _, tt := range tests { + comment := Commentf("for %s", tt.sql) + if len(tt.pruneMode) > 0 { + s.ctx.GetSessionVars().PartitionPruneMode.Store(tt.pruneMode) } - return "", false + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + c.Assert(err, IsNil, comment) + builder, _ := NewPlanBuilder().Init(s.ctx, s.is, &hint.BlockHintProcessor{}) + p, err := builder.Build(ctx, stmt) + c.Assert(err, IsNil, comment) + lp, ok := p.(LogicalPlan) + c.Assert(ok, IsTrue, comment) + // We check predicate columns twice, before and after logical optimization. Some logical plan patterns may occur before + // logical optimization while others may occur after logical optimization. + checkColumnStatsUsage(c, s.is, lp, false, tt.res, comment) + lp, err = logicalOptimize(ctx, builder.GetOptFlag(), lp) + c.Assert(err, IsNil, comment) + checkColumnStatsUsage(c, s.is, lp, false, tt.res, comment) } - checkPredicateColumns := func(lp plannercore.LogicalPlan, expected []string, comment string) { - tblColIDs := plannercore.CollectPredicateColumnsForTest(lp) - cols := make([]string, 0, len(tblColIDs)) - for _, tblColID := range tblColIDs { - col, ok := getColName(tblColID) - require.True(t, ok, comment) - cols = append(cols, col) - } - require.ElementsMatch(t, expected, cols, comment) +} + +func (s *testPlanSuite) TestCollectHistNeededColumns(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + pruneMode string + sql string + res []string + }{ + { + sql: "select * from t where a > 2", + res: []string{"t.a"}, + }, + { + sql: "select * from t where b in (2, 5) or c = 5", + res: []string{"t.b", "t.c"}, + }, + { + sql: "select * from t where a + b > 1", + res: []string{"t.a", "t.b"}, + }, + { + sql: "select b, count(a) from t where b > 1 group by b having count(a) > 2", + res: []string{"t.b"}, + }, + { + sql: "select * from t as x join t2 as y on x.b + y.b > 2 and x.c > 1 and y.a < 1", + res: []string{"t.c", "t2.a"}, + }, + { + sql: "select * from t2 where t2.b > all(select b from t where t.c > 2)", + res: []string{"t.c"}, + }, + { + sql: "select * from t2 where t2.b > any(select b from t where t.c > 2)", + res: []string{"t.c"}, + }, + { + sql: "select * from t2 where t2.b in (select b from t where t.c > 2)", + res: []string{"t.c"}, + }, + { + pruneMode: "static", + sql: "select * from pt1 where ptn < 20 and b > 1", + res: []string{"pt1.p1.b", "pt1.p1.ptn", "pt1.p2.b", "pt1.p2.ptn"}, + }, + { + pruneMode: "dynamic", + sql: "select * from pt1 where ptn < 20 and b > 1", + res: []string{"pt1.b", "pt1.ptn"}, + }, } + ctx := context.Background() for _, tt := range tests { - comment := fmt.Sprintf("for %s", tt.sql) - logutil.BgLogger().Info(comment) - stmts, err := tk.Session().Parse(ctx, tt.sql) - require.NoError(t, err, comment) - stmt := stmts[0] - err = plannercore.Preprocess(sctx, stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) - require.NoError(t, err, comment) - builder, _ := plannercore.NewPlanBuilder().Init(sctx, is, &hint.BlockHintProcessor{}) + comment := Commentf("for %s", tt.sql) + if len(tt.pruneMode) > 0 { + s.ctx.GetSessionVars().PartitionPruneMode.Store(tt.pruneMode) + } + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + c.Assert(err, IsNil, comment) + builder, _ := NewPlanBuilder().Init(s.ctx, s.is, &hint.BlockHintProcessor{}) p, err := builder.Build(ctx, stmt) - require.NoError(t, err, comment) - lp, ok := p.(plannercore.LogicalPlan) - require.True(t, ok, comment) - // We check predicate columns twice, before and after logical optimization. Some logical plan patterns may occur before - // logical optimization while others may occur after logical optimization. - // logutil.BgLogger().Info("before logical opt", zap.String("lp", plannercore.ToString(lp))) - checkPredicateColumns(lp, tt.res, comment) - lp, err = plannercore.LogicalOptimize(ctx, builder.GetOptFlag(), lp) - require.NoError(t, err, comment) - // logutil.BgLogger().Info("after logical opt", zap.String("lp", plannercore.ToString(lp))) - checkPredicateColumns(lp, tt.res, comment) + c.Assert(err, IsNil, comment) + lp, ok := p.(LogicalPlan) + c.Assert(ok, IsTrue, comment) + flags := builder.GetOptFlag() + // JoinReOrder may need columns stats so collecting hist-needed columns must happen before JoinReOrder. + // Hence we disable JoinReOrder and PruneColumnsAgain here. + flags &= ^(flagJoinReOrder | flagPrunColumnsAgain) + lp, err = logicalOptimize(ctx, flags, lp) + c.Assert(err, IsNil, comment) + checkColumnStatsUsage(c, s.is, lp, true, tt.res, comment) } } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index e381da64fcdb6..0136545eff430 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -59,8 +59,22 @@ type testPlanSuite struct { } func (s *testPlanSuite) SetUpSuite(c *C) { - s.is = infoschema.MockInfoSchema([]*model.TableInfo{MockSignedTable(), MockUnsignedTable(), MockView(), MockNoPKTable(), - MockRangePartitionTable(), MockHashPartitionTable(), MockListPartitionTable()}) + tblInfos := []*model.TableInfo{MockSignedTable(), MockUnsignedTable(), MockView(), MockNoPKTable(), + MockRangePartitionTable(), MockHashPartitionTable(), MockListPartitionTable()} + id := int64(0) + for _, tblInfo := range tblInfos { + tblInfo.ID = id + id += 1 + pi := tblInfo.GetPartitionInfo() + if pi == nil { + continue + } + for _, def := range pi.Definitions { + def.ID = id + id += 1 + } + } + s.is = infoschema.MockInfoSchema(tblInfos) s.ctx = MockContext() domain.GetDomain(s.ctx).MockInfoCacheAndLoadInfoSchema(s.is) s.ctx.GetSessionVars().EnableWindowFunction = true diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 1b0f6c4543985..7fa43ab3ce80b 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -1307,7 +1307,7 @@ type LogicalCTETable struct { name string idForStorage int - // seedSchema is only used in predicateColumnCollector to get column mapping + // seedSchema is only used in columnStatsUsageCollector to get column mapping seedSchema *expression.Schema }