diff --git a/pkg/expression/aggregation/base_func.go b/pkg/expression/aggregation/base_func.go index 3de4f3299d976..39a1b518ea8c7 100644 --- a/pkg/expression/aggregation/base_func.go +++ b/pkg/expression/aggregation/base_func.go @@ -214,6 +214,11 @@ func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) { } } +// TypeInfer4FinalCount infers the type of sum agg which is rewritten from final count agg run on MPP mode. +func (a *baseFuncDesc) TypeInfer4FinalCount(finalCountRetType *types.FieldType) { + a.RetTp = finalCountRetType.Clone() +} + // typeInfer4Avg should returns a "decimal", otherwise it returns a "double". // Because child returns integer or decimal type. func (a *baseFuncDesc) typeInfer4Avg(sessionctx.Context) { diff --git a/pkg/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json b/pkg/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json index 2dc6c93121651..7dc03ab5a1154 100644 --- a/pkg/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json +++ b/pkg/planner/core/casetest/enforcempp/testdata/enforce_mpp_suite_out.json @@ -681,21 +681,20 @@ "└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6", " └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.o_id", - " └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#27, funcs:sum(Column#25)->Column#6, funcs:firstrow(Column#26)->test.o.o_id", - " └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#25, Column#8->Column#26, test.o.o_id->Column#27", - " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ", - " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]", - " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", - " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ", - " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST", - " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id", - " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id", - " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ", - " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]", - " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9", - " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", - " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", - " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo" + " └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id", + " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ", + " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]", + " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", + " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST", + " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id", + " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id", + " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]", + " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9", + " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", + " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", + " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo" ], "Warn": null }, @@ -706,21 +705,20 @@ "└─ExchangeSender_79 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6", " └─Projection_78 8000.00 mpp[tiflash] Column#6, test.o.c_id", - " └─HashAgg_77 8000.00 mpp[tiflash] group by:Column#23, funcs:sum(Column#21)->Column#6, funcs:firstrow(Column#22)->test.o.c_id", - " └─Projection_81 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#21, Column#8->Column#22, test.o.c_id->Column#23", - " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ", - " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]", - " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", - " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ", - " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST", - " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id", - " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id", - " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ", - " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]", - " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9", - " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", - " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", - " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo" + " └─HashAgg_77 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id", + " └─ExchangeReceiver_73 9990.00 mpp[tiflash] ", + " └─ExchangeSender_72 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]", + " └─HashJoin_71 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", + " ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST", + " │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id", + " │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id", + " │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]", + " │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9", + " │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", + " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", + " └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo" ], "Warn": null }, @@ -733,21 +731,20 @@ " └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough", " └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10", " └─Projection 16000.00 mpp[tiflash] Column#9, Column#7", - " └─HashAgg 16000.00 mpp[tiflash] group by:Column#38, funcs:sum(Column#36)->Column#9, funcs:firstrow(Column#37)->Column#7", - " └─Projection 16000.00 mpp[tiflash] cast(Column#10, decimal(20,0) BINARY)->Column#36, Column#11->Column#37, Column#7->Column#38", - " └─ExchangeReceiver 16000.00 mpp[tiflash] ", - " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]", - " └─Union 16000.00 mpp[tiflash] ", - " ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", - " │ └─ExchangeReceiver 8000.00 mpp[tiflash] ", - " │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", - " │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30", - " │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", - " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", - " └─ExchangeReceiver 8000.00 mpp[tiflash] ", - " └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", - " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33", - " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#7, funcs:sum(Column#10)->Column#9, funcs:firstrow(Column#11)->Column#7", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]", + " └─Union 16000.00 mpp[tiflash] ", + " ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", + " │ └─ExchangeReceiver 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", + " │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", + " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", + " └─ExchangeReceiver 8000.00 mpp[tiflash] ", + " └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]", + " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33", + " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null } diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 7601111e36e5c..832c8277f445e 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -3244,7 +3244,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) { for i, agg := range aggFuncs { if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount { + oldFT := agg.RetTp aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false) + aggFuncs[i].TypeInfer4FinalCount(oldFT) } } }