diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 89930b175532..eba5ffe3f457 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2043,7 +2043,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_internal_exprs: 1 })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2070,7 +2070,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_internal_exprs: 1 })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: 
[(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2335,7 +2335,7 @@ mod tests { .expect("hash aggregate"); assert_eq!( "sum(aggregate_test_100.c3)", - final_hash_agg.schema().field(2).name() + final_hash_agg.schema().field(3).name() ); // we need access to the input to the partial aggregate so that other projects can // implement serde diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cc8ddf8ec8e8..49f2d53b592e 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -990,9 +990,28 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; - Aggregate::try_new(self.plan, group_expr, aggr_expr) - .map(LogicalPlan::Aggregate) - .map(Self::new) + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + + let aggregate = Aggregate::try_new(self.plan, group_expr, aggr_expr)?; + let group_expr_len = aggregate.group_expr_len()?; + + let aggregate = LogicalPlan::Aggregate(aggregate); + + if is_grouping_set { + // For grouping sets we do a project to not expose the internal grouping id + let exprs = aggregate + .schema() + .columns() + .into_iter() + .enumerate() + .filter(|(idx, _)| *idx != group_expr_len - 1) + .map(|(_, column)| Expr::Column(column)) + .collect::>(); + Self::new(aggregate).project(exprs) + } else { + Ok(Self::new(aggregate)) + } } /// Create an expression to represent the explanation of the plan diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 443d23804adb..34c933ff1bc2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use 
std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::dml::CopyTo; use super::DdlStatement; @@ -2964,6 +2964,10 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + qualified_fields.push(( + None, + Field::new(Self::INTERNAL_GROUPING_ID, DataType::UInt8, false).into(), + )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); @@ -3015,9 +3019,19 @@ impl Aggregate { }) } + fn is_grouping_set(&self) -> bool { + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + } + /// Get the output expressions. fn output_expressions(&self) -> Result> { + static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; + if self.is_grouping_set() { + exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { + Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + })); + } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); Ok(exprs) @@ -3029,6 +3043,8 @@ impl Aggregate { pub fn group_expr_len(&self) -> Result { grouping_set_expr_count(&self.group_expr) } + + pub const INTERNAL_GROUPING_ID: &str = "__grouping_id"; } // Manual implementation needed because of `schema` field. Comparison excludes this field. 
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 1d8eb9445eda..3464f4aa37b9 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -65,7 +65,8 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { "Invalid group by expressions, GroupingSet must be the only expression" ); } - Ok(grouping_set.distinct_expr().len()) + // Grouping sets have an additional internal column for the grouping id + Ok(grouping_set.distinct_expr().len() + 1) } else { Ok(group_expr.len()) } diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index bc1642bf7952..4e352e25b52c 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { // Compare output expressions of the partial, and input expressions of the final operator. 
physical_exprs_equal( - &input_group_by.output_exprs(&AggregateMode::Partial), + &input_group_by.output_exprs(), &final_group_by.input_exprs(), ) && input_group_by.groups() == final_group_by.groups() && input_group_by.null_expr().len() == final_group_by.null_expr().len() diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 124dcdc8d171..470cb1d88a9b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -40,7 +40,7 @@ use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, expressions::Column, @@ -110,8 +110,6 @@ impl AggregateMode { } } -const INTERNAL_GROUPING_ID: &str = "grouping_id"; - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. @@ -141,10 +139,6 @@ pub struct PhysicalGroupBy { /// expression in null_expr. If `groups[i][j]` is true, then the /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. groups: Vec>, - // The number of internal expressions that are used to implement grouping - // sets. 
These output are removed from the final output and not in `expr` - // as they are generated based on the value in `groups` - num_internal_exprs: usize, } impl PhysicalGroupBy { @@ -154,12 +148,10 @@ impl PhysicalGroupBy { null_expr: Vec<(Arc, String)>, groups: Vec>, ) -> Self { - let num_internal_exprs = if !null_expr.is_empty() { 1 } else { 0 }; Self { expr, null_expr, groups, - num_internal_exprs, } } @@ -171,7 +163,6 @@ impl PhysicalGroupBy { expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], - num_internal_exprs: 0, } } @@ -222,20 +213,17 @@ impl PhysicalGroupBy { } /// The number of expressions in the output schema. - fn num_output_exprs(&self, mode: &AggregateMode) -> usize { + fn num_output_exprs(&self) -> usize { let mut num_exprs = self.expr.len(); if !self.is_single() { - num_exprs += self.num_internal_exprs; - } - if *mode != AggregateMode::Partial { - num_exprs -= self.num_internal_exprs; + num_exprs += 1 } num_exprs } /// Return grouping expressions as they occur in the output schema. 
- pub fn output_exprs(&self, mode: &AggregateMode) -> Vec> { - let num_output_exprs = self.num_output_exprs(mode); + pub fn output_exprs(&self) -> Vec> { + let num_output_exprs = self.num_output_exprs(); let mut output_exprs = Vec::with_capacity(num_output_exprs); output_exprs.extend( self.expr @@ -244,9 +232,11 @@ impl PhysicalGroupBy { .take(num_output_exprs) .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), ); - if !self.is_single() && *mode == AggregateMode::Partial { - output_exprs - .push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _); + if !self.is_single() { + output_exprs.push(Arc::new(Column::new( + Aggregate::INTERNAL_GROUPING_ID, + self.expr.len(), + )) as _); } output_exprs } @@ -256,7 +246,7 @@ impl PhysicalGroupBy { if self.is_single() { self.expr.len() } else { - self.expr.len() + self.num_internal_exprs + self.expr.len() + 1 } } @@ -290,7 +280,7 @@ impl PhysicalGroupBy { } if !self.is_single() { fields.push(Field::new( - INTERNAL_GROUPING_ID, + Aggregate::INTERNAL_GROUPING_ID, self.grouping_id_type(), false, )); @@ -302,35 +292,29 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. - fn output_fields( - &self, - input_schema: &Schema, - mode: &AggregateMode, - ) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; - fields.truncate(self.num_output_exprs(mode)); + fields.truncate(self.num_output_exprs()); Ok(fields) } /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial /// aggregation. 
pub fn as_final(&self) -> PhysicalGroupBy { - let expr: Vec<_> = self - .output_exprs(&AggregateMode::Partial) - .into_iter() - .zip( - self.expr - .iter() - .map(|t| t.1.clone()) - .chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())), - ) - .collect(); + let expr: Vec<_> = + self.output_exprs() + .into_iter() + .zip( + self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once( + Aggregate::INTERNAL_GROUPING_ID.to_owned(), + )), + ) + .collect(); let num_exprs = expr.len(); Self { expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], - num_internal_exprs: self.num_internal_exprs, } } } @@ -567,7 +551,7 @@ impl AggregateExec { /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { - self.group_by.output_exprs(&AggregateMode::Partial) + self.group_by.output_exprs() } /// Aggregate expressions @@ -901,9 +885,8 @@ fn create_schema( aggr_expr: &[AggregateFunctionExpr], mode: AggregateMode, ) -> Result { - let mut fields = - Vec::with_capacity(group_by.num_output_exprs(&mode) + aggr_expr.len()); - fields.extend(group_by.output_fields(input_schema, &mode)?); + let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema)?); match mode { AggregateMode::Partial => { @@ -1506,49 +1489,49 @@ mod tests { // In spill mode, we test with the limited memory, if the mem usage exceeds, // we trigger the early emit rule, which turns out the partial aggregate result. 
vec![ - "+---+-----+-------------+-----------------+", - "| a | b | grouping_id | COUNT(1)[count] |", - "+---+-----+-------------+-----------------+", - "| | 1.0 | 2 | 1 |", - "| | 1.0 | 2 | 1 |", - "| | 2.0 | 2 | 1 |", - "| | 2.0 | 2 | 1 |", - "| | 3.0 | 2 | 1 |", - "| | 3.0 | 2 | 1 |", - "| | 4.0 | 2 | 1 |", - "| | 4.0 | 2 | 1 |", - "| 2 | | 1 | 1 |", - "| 2 | | 1 | 1 |", - "| 2 | 1.0 | 0 | 1 |", - "| 2 | 1.0 | 0 | 1 |", - "| 3 | | 1 | 1 |", - "| 3 | | 1 | 2 |", - "| 3 | 2.0 | 0 | 2 |", - "| 3 | 3.0 | 0 | 1 |", - "| 4 | | 1 | 1 |", - "| 4 | | 1 | 2 |", - "| 4 | 3.0 | 0 | 1 |", - "| 4 | 4.0 | 0 | 2 |", - "+---+-----+-------------+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] } else { vec![ - "+---+-----+-------------+-----------------+", - "| a | b | grouping_id | COUNT(1)[count] |", - "+---+-----+-------------+-----------------+", - "| | 1.0 | 2 | 2 |", - "| | 2.0 | 2 | 2 |", - "| | 3.0 | 2 | 2 |", - "| | 4.0 | 2 | 2 |", - "| 2 | | 1 | 2 |", - "| 2 | 1.0 | 0 | 2 |", - "| 3 | | 1 | 3 |", - "| 3 | 2.0 | 0 | 2 |", - "| 3 | 3.0 | 0 | 1 |", - "| 4 | | 1 | 3 |", - "| 4 | 3.0 | 0 | 1 |", - "| 4 | 4.0 | 0 | 2 |", - "+---+-----+-------------+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 
2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 6c1733b14808..b5b49a13afad 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -491,7 +491,7 @@ impl GroupedHashAggregateStream { let (ordering, _) = agg .properties() .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs(&agg.mode)); + .find_longest_permutation(&agg_group_by.output_exprs()); let group_ordering = GroupOrdering::try_new( &group_schema, &agg.input_order_mode, @@ -845,7 +845,7 @@ impl GroupedHashAggregateStream { let mut output = self.group_values.emit(emit_to)?; if !spilling { - output.truncate(self.group_by.num_output_exprs(&self.mode)); + output.truncate(self.group_by.num_output_exprs()); } if let EmitTo::First(n) = emit_to { self.group_ordering.remove_groups(n); diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c93d9e6fc435..a8298b8211c6 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -714,10 +714,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? 
.build()?; - let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { - &agg.group_expr - } else { - unreachable!(); + + let group_by_exprs = match &plan { + LogicalPlan::Aggregate(agg) => &agg.group_expr, + LogicalPlan::Projection(proj) => match *proj.input { + LogicalPlan::Aggregate(ref agg) => &agg.group_expr.clone(), + _ => unreachable!(), + }, + _ => unreachable!(), }; // in this next section of code we are re-writing the projection to refer to columns diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b7452dd84cfb..c1652ee3b0af 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4883,16 +4883,18 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -01)Limit: skip=0, fetch=3 -02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -03)----TableScan: aggregate_test_100 projection=[c2, c3] +01)Projection: aggregate_test_100.c2, aggregate_test_100.c3 +02)--Limit: skip=0, fetch=3 +03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=3 -02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, grouping_id@2 as grouping_id], aggr=[], lim=[3] -03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true +01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, 
__grouping_id@2 as __grouping_id], aggr=[], lim=[3] +04)------CoalescePartitionsExec +05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true query II SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;