diff --git a/executor/test/tiflashtest/BUILD.bazel b/executor/test/tiflashtest/BUILD.bazel
index 908ea69b2c10f..2f6658a8fa74f 100644
--- a/executor/test/tiflashtest/BUILD.bazel
+++ b/executor/test/tiflashtest/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
     ],
     flaky = True,
     race = "on",
-    shard_count = 38,
+    shard_count = 39,
     deps = [
         "//config",
         "//domain",
diff --git a/executor/test/tiflashtest/tiflash_test.go b/executor/test/tiflashtest/tiflash_test.go
index a0a42a639d46d..bfc8d1caf8ff2 100644
--- a/executor/test/tiflashtest/tiflash_test.go
+++ b/executor/test/tiflashtest/tiflash_test.go
@@ -1236,6 +1236,37 @@ func TestAggPushDownCountStar(t *testing.T) {
 	tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
 }
 
+func TestAggPushDownUnionAndMPP(t *testing.T) {
+	store := testkit.CreateMockStore(t, withMockTiFlash(2))
+	tk := testkit.NewTestKit(t, store)
+
+	tk.MustExec("use test")
+	tk.MustExec("create table t (a int, b int)")
+	tk.MustExec("alter table t set tiflash replica 1")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("set @@tidb_allow_mpp=1;")
+	tk.MustExec("set @@tidb_enforce_mpp=1;")
+	tk.MustExec("set @@tidb_opt_agg_push_down=1")
+
+	tk.MustExec("create table c(c_id int)")
+	tk.MustExec("create table o(o_id int, c_id int)")
+	tk.MustExec("insert into c values(1),(1),(1),(1)")
+	tk.MustExec("insert into o values(1,1),(1,1),(1,2)")
+	tk.MustExec("alter table c set tiflash replica 1")
+	tk.MustExec("alter table o set tiflash replica 1")
+
+	tk.MustQuery("select a, count(*) from (select a, b from t " +
+		"union all " +
+		"select a, b from t" +
+		") t group by a order by a limit 10;").Check(testkit.Rows("1 10"))
+
+	tk.MustQuery("select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id").Check(testkit.Rows("1 12"))
+}
+
 func TestGroupStreamAggOnTiFlash(t *testing.T) {
 	store := testkit.CreateMockStore(t, withMockTiFlash(2))
 	tk := testkit.NewTestKit(t, store)
diff --git a/planner/core/casetest/enforcempp/enforce_mpp_test.go b/planner/core/casetest/enforcempp/enforce_mpp_test.go
index 414604003c08d..b0e20e36e01a9 100644
--- a/planner/core/casetest/enforcempp/enforce_mpp_test.go
+++ b/planner/core/casetest/enforcempp/enforce_mpp_test.go
@@ -354,13 +354,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) {
 	tk.MustExec("create table c(c_id bigint)")
 	tk.MustExec("create table o(o_id bigint, c_id bigint not null)")
+	tk.MustExec("create table t (a int, b int)")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+	tk.MustExec("insert into t values (1, 1);")
+
 	// Create virtual tiflash replica info.
 	dom := domain.GetDomain(tk.Session())
 	is := dom.InfoSchema()
 	db, exists := is.SchemaByName(model.NewCIStr("test"))
 	require.True(t, exists)
 	for _, tblInfo := range db.Tables {
-		if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
+		if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" {
 			tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
 				Count:     1,
 				Available: true,
diff --git a/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_in.json b/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_in.json
index 17c1d8c02e8da..384d9a8c01b9b 100644
--- a/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_in.json
+++ b/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_in.json
@@ -93,7 +93,8 @@
       "set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
       "EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
       "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
-      "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
+      "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
+      "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10"
     ]
   },
   {
diff --git a/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json b/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json
index 681f3fd15cbdc..2dc6c93121651 100644
--- a/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json
+++ b/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json
@@ -681,20 +681,21 @@
           "└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough",
           " └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
           " └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.o_id",
-          " └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
-          " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
-          " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
-          " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
-          " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
-          " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
-          " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
-          " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
-          " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
-          " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
-          " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
-          " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
-          " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
-          " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
+          " └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#27, funcs:sum(Column#25)->Column#6, funcs:firstrow(Column#26)->test.o.o_id",
+          " └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#25, Column#8->Column#26, test.o.o_id->Column#27",
+          " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
+          " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
+          " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
+          " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
+          " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
+          " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
+          " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
+          " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
+          " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
+          " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
+          " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
+          " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
+          " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
         ],
         "Warn": null
       },
@@ -705,20 +706,48 @@
           "└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough",
           " └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
           " └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.c_id",
-          " └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
-          " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
-          " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
-          " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
-          " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
-          " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
-          " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
-          " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
-          " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
-          " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
-          " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
-          " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
-          " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
-          " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
+          " └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#23, funcs:sum(Column#21)->Column#6, funcs:firstrow(Column#22)->test.o.c_id",
+          " └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#21, Column#8->Column#22, test.o.c_id->Column#23",
+          " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ",
+          " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
+          " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
+          " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
+          " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
+          " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
+          " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
+          " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
+          " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
+          " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
+          " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
+          " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
+          " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
+        ],
+        "Warn": null
+      },
+      {
+        "SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10",
+        "Plan": [
+          "Projection 10.00 root Column#7, Column#9",
+          "└─TopN 10.00 root Column#7, offset:0, count:10",
+          " └─TableReader 10.00 root MppVersion: 2, data:ExchangeSender",
+          " └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough",
+          " └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10",
+          " └─Projection 16000.00 mpp[tiflash] Column#9, Column#7",
+          " └─HashAgg 16000.00 mpp[tiflash] group by:Column#38, funcs:sum(Column#36)->Column#9, funcs:firstrow(Column#37)->Column#7",
+          " └─Projection 16000.00 mpp[tiflash] cast(Column#10, decimal(20,0) BINARY)->Column#36, Column#11->Column#37, Column#7->Column#38",
+          " └─ExchangeReceiver 16000.00 mpp[tiflash] ",
+          " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]",
+          " └─Union 16000.00 mpp[tiflash] ",
+          " ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
+          " │ └─ExchangeReceiver 8000.00 mpp[tiflash] ",
+          " │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
+          " │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30",
+          " │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo",
+          " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
+          " └─ExchangeReceiver 8000.00 mpp[tiflash] ",
+          " └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
+          " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33",
+          " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"
         ],
         "Warn": null
       }
diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go
index c1034150f75a2..e24f9f291b113 100644
--- a/planner/core/exhaust_physical_plans.go
+++ b/planner/core/exhaust_physical_plans.go
@@ -3216,6 +3216,16 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 	// Is this aggregate a final stage aggregate?
 	// Final agg can't be split into multi-stage aggregate
 	hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode
+	// A final-stage count agg should be rewritten to sum on the MPP execution path.
+	// In the traditional case, TiDB takes the final-agg role and pushes the partial agg down to TiKV,
+	// so TiDB can tell the mode is partial and sums the partial counts instead of re-counting them; MPP cannot.
+	finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) {
+		for i, agg := range aggFuncs {
+			if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount {
+				aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false)
+			}
+		}
+	}
 
 	if len(la.GroupByItems) > 0 {
 		partitionCols := la.GetPotentialPartitionKeys()
@@ -3248,6 +3258,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
 			agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp)
 			agg.SetSchema(la.schema.Clone())
 			agg.MppRunMode = Mpp1Phase
+			finalAggAdjust(agg.AggFuncs)
 			hashAggs = append(hashAggs, agg)
 		}
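
Note on the planner change above: with tidb_opt_agg_push_down, the aggregate is split so that a partial count runs below the join or union and a FinalMode aggregate runs on top. The values reaching that top aggregate are already partial counts, so executing it as a plain count would count the number of partial result rows instead of summing them; finalAggAdjust therefore rewrites a FinalMode count into sum, which is what the funcs:sum(...) lines in the new expected plans show. A minimal standalone Go sketch of that partial/final contract (illustrative only; partialCount and finalSum are not TiDB functions):

package main

import "fmt"

// partialCount models the partial stage: each MPP task counts its own rows per group key.
func partialCount(rows []string) map[string]int64 {
	partial := make(map[string]int64)
	for _, key := range rows {
		partial[key]++
	}
	return partial
}

// finalSum models the final stage after the exchange: the inputs are already partial
// counts, so they must be summed; re-counting the input rows would return the number
// of partial results per key, not the true row count.
func finalSum(partials []map[string]int64) map[string]int64 {
	final := make(map[string]int64)
	for _, p := range partials {
		for key, cnt := range p {
			final[key] += cnt
		}
	}
	return final
}

func main() {
	task1 := partialCount([]string{"a", "a", "b"})
	task2 := partialCount([]string{"a", "b", "b"})
	fmt.Println(finalSum([]map[string]int64{task1, task2})) // map[a:3 b:3]
}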