Skip to content

Commit

Permalink
Add grouping_id to the logical plan
Browse files Browse the repository at this point in the history
  • Loading branch information
eejbyfeldt committed Oct 1, 2024
1 parent e82e069 commit e9c7e13
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 133 deletions.
17 changes: 17 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,26 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
// For grouping sets we do a project to not expose the internal grouping id
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2043,7 +2043,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_internal_exprs: 1 })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2070,7 +2070,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_internal_exprs: 1 })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2335,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use super::dml::CopyTo;
use super::DdlStatement;
Expand Down Expand Up @@ -2964,6 +2964,10 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
qualified_fields.push((
None,
Field::new(Self::INTERNAL_GROUPING_ID, DataType::UInt8, false).into(),
));
}

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
Expand Down Expand Up @@ -3015,9 +3019,19 @@ impl Aggregate {
})
}

fn is_grouping_set(&self) -> bool {
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
}

/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
if self.is_grouping_set() {
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
}));
}
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
Expand All @@ -3029,6 +3043,8 @@ impl Aggregate {
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}

pub const INTERNAL_GROUPING_ID: &str = "__grouping_id";
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
Ok(grouping_set.distinct_expr().len())
// Groupings sets have an additional interal column for the grouping id
Ok(grouping_set.distinct_expr().len() + 1)
} else {
Ok(group_expr.len())
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -373,7 +373,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -392,7 +392,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {

// Compare output expressions of the partial, and input expressions of the final operator.
physical_exprs_equal(
&input_group_by.output_exprs(&AggregateMode::Partial),
&input_group_by.output_exprs(),
&final_group_by.input_exprs(),
) && input_group_by.groups() == final_group_by.groups()
&& input_group_by.null_expr().len() == final_group_by.null_expr().len()
Expand Down
Loading

0 comments on commit e9c7e13

Please sign in to comment.