Skip to content

Commit

Permalink
planner: fix wrong result when pushing Agg down through Union in MPP …
Browse files Browse the repository at this point in the history
…plans (#46310)
  • Loading branch information
AilinKid authored Aug 30, 2023
1 parent 72b9fb2 commit 4facc13
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 31 deletions.
2 changes: 1 addition & 1 deletion executor/test/tiflashtest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 38,
shard_count = 39,
deps = [
"//config",
"//domain",
Expand Down
31 changes: 31 additions & 0 deletions executor/test/tiflashtest/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,37 @@ func TestAggPushDownCountStar(t *testing.T) {
tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
}

// TestAggPushDownUnionAndMPP runs two aggregate queries with MPP allowed and
// enforced and with aggregation push-down switched on, and checks the returned
// rows: one aggregate sits on top of a UNION ALL, the other on top of a join.
func TestAggPushDownUnionAndMPP(t *testing.T) {
	store := testkit.CreateMockStore(t, withMockTiFlash(2))
	tk := testkit.NewTestKit(t, store)
	tk.MustExec("use test")

	// Table t gets a TiFlash replica and five identical (1, 1) rows.
	tk.MustExec("create table t (a int, b int)")
	tk.MustExec("alter table t set tiflash replica 1")
	for inserted := 0; inserted < 5; inserted++ {
		tk.MustExec("insert into t values (1, 1);")
	}

	// Session switches first, then the join tables c and o with their data
	// and TiFlash replicas; statements run in exactly this order.
	setupStmts := []string{
		"set @@tidb_allow_mpp=1;",
		"set @@tidb_enforce_mpp=1;",
		"set @@tidb_opt_agg_push_down=1",
		"create table c(c_id int)",
		"create table o(o_id int, c_id int)",
		"insert into c values(1),(1),(1),(1)",
		"insert into o values(1,1),(1,1),(1,2)",
		"alter table c set tiflash replica 1",
		"alter table o set tiflash replica 1",
	}
	for _, stmt := range setupStmts {
		tk.MustExec(stmt)
	}

	// Both UNION ALL branches read the same five rows of t, so the single
	// group a=1 must count 10 rows.
	const unionAggSQL = "select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10;"
	unionResult := tk.MustQuery(unionAggSQL)
	unionResult.Check(testkit.Rows("1 10"))

	// Four rows of c (c_id=1) join three rows of o (o_id=1): 4 * 3 = 12.
	const joinAggSQL = "select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id"
	joinResult := tk.MustQuery(joinAggSQL)
	joinResult.Check(testkit.Rows("1 12"))
}

func TestGroupStreamAggOnTiFlash(t *testing.T) {
store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
Expand Down
9 changes: 8 additions & 1 deletion planner/core/casetest/enforcempp/enforce_mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) {
tk.MustExec("create table c(c_id bigint)")
tk.MustExec("create table o(o_id bigint, c_id bigint not null)")

tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
"EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
"EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,20 +681,21 @@
"└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
" └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
" └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
" └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
" └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#27, funcs:sum(Column#25)->Column#6, funcs:firstrow(Column#26)->test.o.o_id",
" └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#25, Column#8->Column#26, test.o.o_id->Column#27",
" └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
" └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
],
"Warn": null
},
Expand All @@ -705,20 +706,48 @@
"└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
" └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
" └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
" └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
" └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#23, funcs:sum(Column#21)->Column#6, funcs:firstrow(Column#22)->test.o.c_id",
" └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#21, Column#8->Column#22, test.o.c_id->Column#23",
" └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
" └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10",
"Plan": [
"Projection 10.00 root Column#7, Column#9",
"└─TopN 10.00 root Column#7, offset:0, count:10",
" └─TableReader 10.00 root MppVersion: 2, data:ExchangeSender",
" └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough",
" └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10",
" └─Projection 16000.00 mpp[tiflash] Column#9, Column#7",
" └─HashAgg 16000.00 mpp[tiflash] group by:Column#38, funcs:sum(Column#36)->Column#9, funcs:firstrow(Column#37)->Column#7",
" └─Projection 16000.00 mpp[tiflash] cast(Column#10, decimal(20,0) BINARY)->Column#36, Column#11->Column#37, Column#7->Column#38",
" └─ExchangeReceiver 16000.00 mpp[tiflash] ",
" └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]",
" └─Union 16000.00 mpp[tiflash] ",
" ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" │ └─ExchangeReceiver 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
" │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo",
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" └─ExchangeReceiver 8000.00 mpp[tiflash] ",
" └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33",
" └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"
],
"Warn": null
}
Expand Down
11 changes: 11 additions & 0 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -3216,6 +3216,16 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
// Is this aggregate a final stage aggregate?
// Final agg can't be split into multi-stage aggregate
hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode
// A final-stage count must become a sum on the MPP execution path.
// In the traditional (non-MPP) case, TiDB takes the final-agg role and pushes the
// partial agg down to TiKV; TiDB can tell the partial mode apart and sums the
// partial counts rather than counting them again, but MPP cannot.
// finalAggAdjust walks aggFuncs and replaces, in place, every descriptor that is a
// FinalMode count with a newly built sum descriptor over the same arguments.
finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) {
for i, agg := range aggFuncs {
if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount {
// NOTE(review): the error returned by NewAggFuncDesc is discarded and the
// result is stored unconditionally — confirm this call cannot fail here.
// NOTE(review): the mode of the new descriptor is not set explicitly in this
// closure — verify it ends up as intended for a final-stage aggregate.
aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false)
}
}
}

if len(la.GroupByItems) > 0 {
partitionCols := la.GetPotentialPartitionKeys()
Expand Down Expand Up @@ -3248,6 +3258,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp1Phase
finalAggAdjust(agg.AggFuncs)
hashAggs = append(hashAggs, agg)
}

Expand Down

0 comments on commit 4facc13

Please sign in to comment.