Skip to content

Commit

Permalink
Add grouping_id to the logical plan
Browse files Browse the repository at this point in the history
  • Loading branch information
eejbyfeldt committed Oct 1, 2024
1 parent e82e069 commit 2914017
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 107 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2043,7 +2043,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]], num_internal_exprs: 1 })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2070,7 +2070,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]], num_internal_exprs: 1 })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2335,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
Expand Down
25 changes: 22 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -990,9 +990,28 @@ impl LogicalPlanBuilder {

let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::new)

let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);

let aggregate = Aggregate::try_new(self.plan, group_expr, aggr_expr)?;
let group_expr_len = aggregate.group_expr_len()?;

let aggregate = LogicalPlan::Aggregate(aggregate);

if is_grouping_set {
// For grouping sets we do a project to not expose the internal grouping id
let exprs = aggregate
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != group_expr_len - 1)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
Self::new(aggregate).project(exprs)
} else {
Ok(Self::new(aggregate))
}
}

/// Create an expression to represent the explanation of the plan
Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use super::dml::CopyTo;
use super::DdlStatement;
Expand Down Expand Up @@ -2964,6 +2964,10 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
qualified_fields.push((
None,
Field::new(Self::INTERNAL_GROUPING_ID, DataType::UInt8, false).into(),
));
}

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
Expand Down Expand Up @@ -3015,9 +3019,19 @@ impl Aggregate {
})
}

fn is_grouping_set(&self) -> bool {
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
}

/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
if self.is_grouping_set() {
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
}));
}
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
Expand All @@ -3029,6 +3043,8 @@ impl Aggregate {
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}

pub const INTERNAL_GROUPING_ID: &str = "__grouping_id";
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
Ok(grouping_set.distinct_expr().len())
// Groupings sets have an additional interal column for the grouping id
Ok(grouping_set.distinct_expr().len() + 1)
} else {
Ok(group_expr.len())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {

// Compare output expressions of the partial, and input expressions of the final operator.
physical_exprs_equal(
&input_group_by.output_exprs(&AggregateMode::Partial),
&input_group_by.output_exprs(),
&final_group_by.input_exprs(),
) && input_group_by.groups() == final_group_by.groups()
&& input_group_by.null_expr().len() == final_group_by.null_expr().len()
Expand Down
149 changes: 66 additions & 83 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use arrow_schema::DataType;
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, not_impl_err, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::{
equivalence::{collapse_lex_req, ProjectionMapping},
expressions::Column,
Expand Down Expand Up @@ -110,8 +110,6 @@ impl AggregateMode {
}
}

const INTERNAL_GROUPING_ID: &str = "grouping_id";

/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
/// and a single group [false, false].
Expand Down Expand Up @@ -141,10 +139,6 @@ pub struct PhysicalGroupBy {
/// expression in null_expr. If `groups[i][j]` is true, then the
/// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
groups: Vec<Vec<bool>>,
// The number of internal expressions that are used to implement grouping
// sets. These output are removed from the final output and not in `expr`
// as they are generated based on the value in `groups`
num_internal_exprs: usize,
}

impl PhysicalGroupBy {
Expand All @@ -154,12 +148,10 @@ impl PhysicalGroupBy {
null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
groups: Vec<Vec<bool>>,
) -> Self {
let num_internal_exprs = if !null_expr.is_empty() { 1 } else { 0 };
Self {
expr,
null_expr,
groups,
num_internal_exprs,
}
}

Expand All @@ -171,7 +163,6 @@ impl PhysicalGroupBy {
expr,
null_expr: vec![],
groups: vec![vec![false; num_exprs]],
num_internal_exprs: 0,
}
}

Expand Down Expand Up @@ -222,20 +213,17 @@ impl PhysicalGroupBy {
}

/// The number of expressions in the output schema.
fn num_output_exprs(&self, mode: &AggregateMode) -> usize {
fn num_output_exprs(&self) -> usize {
let mut num_exprs = self.expr.len();
if !self.is_single() {
num_exprs += self.num_internal_exprs;
}
if *mode != AggregateMode::Partial {
num_exprs -= self.num_internal_exprs;
num_exprs += 1
}
num_exprs
}

/// Return grouping expressions as they occur in the output schema.
pub fn output_exprs(&self, mode: &AggregateMode) -> Vec<Arc<dyn PhysicalExpr>> {
let num_output_exprs = self.num_output_exprs(mode);
pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
let num_output_exprs = self.num_output_exprs();
let mut output_exprs = Vec::with_capacity(num_output_exprs);
output_exprs.extend(
self.expr
Expand All @@ -244,9 +232,11 @@ impl PhysicalGroupBy {
.take(num_output_exprs)
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
);
if !self.is_single() && *mode == AggregateMode::Partial {
output_exprs
.push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _);
if !self.is_single() {
output_exprs.push(Arc::new(Column::new(
Aggregate::INTERNAL_GROUPING_ID,
self.expr.len(),
)) as _);
}
output_exprs
}
Expand All @@ -256,7 +246,7 @@ impl PhysicalGroupBy {
if self.is_single() {
self.expr.len()
} else {
self.expr.len() + self.num_internal_exprs
self.expr.len() + 1
}
}

Expand Down Expand Up @@ -290,7 +280,7 @@ impl PhysicalGroupBy {
}
if !self.is_single() {
fields.push(Field::new(
INTERNAL_GROUPING_ID,
Aggregate::INTERNAL_GROUPING_ID,
self.grouping_id_type(),
false,
));
Expand All @@ -302,35 +292,29 @@ impl PhysicalGroupBy {
///
/// This might be different from the `group_fields` that might contain internal expressions that
/// should not be part of the output schema.
fn output_fields(
&self,
input_schema: &Schema,
mode: &AggregateMode,
) -> Result<Vec<Field>> {
fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
let mut fields = self.group_fields(input_schema)?;
fields.truncate(self.num_output_exprs(mode));
fields.truncate(self.num_output_exprs());
Ok(fields)
}

/// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
/// aggregation.
pub fn as_final(&self) -> PhysicalGroupBy {
let expr: Vec<_> = self
.output_exprs(&AggregateMode::Partial)
.into_iter()
.zip(
self.expr
.iter()
.map(|t| t.1.clone())
.chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())),
)
.collect();
let expr: Vec<_> =
self.output_exprs()
.into_iter()
.zip(
self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
Aggregate::INTERNAL_GROUPING_ID.to_owned(),
)),
)
.collect();
let num_exprs = expr.len();
Self {
expr,
null_expr: vec![],
groups: vec![vec![false; num_exprs]],
num_internal_exprs: self.num_internal_exprs,
}
}
}
Expand Down Expand Up @@ -567,7 +551,7 @@ impl AggregateExec {

/// Grouping expressions as they occur in the output schema
pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.group_by.output_exprs(&AggregateMode::Partial)
self.group_by.output_exprs()
}

/// Aggregate expressions
Expand Down Expand Up @@ -901,9 +885,8 @@ fn create_schema(
aggr_expr: &[AggregateFunctionExpr],
mode: AggregateMode,
) -> Result<Schema> {
let mut fields =
Vec::with_capacity(group_by.num_output_exprs(&mode) + aggr_expr.len());
fields.extend(group_by.output_fields(input_schema, &mode)?);
let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
fields.extend(group_by.output_fields(input_schema)?);

match mode {
AggregateMode::Partial => {
Expand Down Expand Up @@ -1506,49 +1489,49 @@ mod tests {
// In spill mode, we test with the limited memory, if the mem usage exceeds,
// we trigger the early emit rule, which turns out the partial aggregate result.
vec![
"+---+-----+-------------+-----------------+",
"| a | b | grouping_id | COUNT(1)[count] |",
"+---+-----+-------------+-----------------+",
"| | 1.0 | 2 | 1 |",
"| | 1.0 | 2 | 1 |",
"| | 2.0 | 2 | 1 |",
"| | 2.0 | 2 | 1 |",
"| | 3.0 | 2 | 1 |",
"| | 3.0 | 2 | 1 |",
"| | 4.0 | 2 | 1 |",
"| | 4.0 | 2 | 1 |",
"| 2 | | 1 | 1 |",
"| 2 | | 1 | 1 |",
"| 2 | 1.0 | 0 | 1 |",
"| 2 | 1.0 | 0 | 1 |",
"| 3 | | 1 | 1 |",
"| 3 | | 1 | 2 |",
"| 3 | 2.0 | 0 | 2 |",
"| 3 | 3.0 | 0 | 1 |",
"| 4 | | 1 | 1 |",
"| 4 | | 1 | 2 |",
"| 4 | 3.0 | 0 | 1 |",
"| 4 | 4.0 | 0 | 2 |",
"+---+-----+-------------+-----------------+",
"+---+-----+---------------+-----------------+",
"| a | b | __grouping_id | COUNT(1)[count] |",
"+---+-----+---------------+-----------------+",
"| | 1.0 | 2 | 1 |",
"| | 1.0 | 2 | 1 |",
"| | 2.0 | 2 | 1 |",
"| | 2.0 | 2 | 1 |",
"| | 3.0 | 2 | 1 |",
"| | 3.0 | 2 | 1 |",
"| | 4.0 | 2 | 1 |",
"| | 4.0 | 2 | 1 |",
"| 2 | | 1 | 1 |",
"| 2 | | 1 | 1 |",
"| 2 | 1.0 | 0 | 1 |",
"| 2 | 1.0 | 0 | 1 |",
"| 3 | | 1 | 1 |",
"| 3 | | 1 | 2 |",
"| 3 | 2.0 | 0 | 2 |",
"| 3 | 3.0 | 0 | 1 |",
"| 4 | | 1 | 1 |",
"| 4 | | 1 | 2 |",
"| 4 | 3.0 | 0 | 1 |",
"| 4 | 4.0 | 0 | 2 |",
"+---+-----+---------------+-----------------+",
]
} else {
vec![
"+---+-----+-------------+-----------------+",
"| a | b | grouping_id | COUNT(1)[count] |",
"+---+-----+-------------+-----------------+",
"| | 1.0 | 2 | 2 |",
"| | 2.0 | 2 | 2 |",
"| | 3.0 | 2 | 2 |",
"| | 4.0 | 2 | 2 |",
"| 2 | | 1 | 2 |",
"| 2 | 1.0 | 0 | 2 |",
"| 3 | | 1 | 3 |",
"| 3 | 2.0 | 0 | 2 |",
"| 3 | 3.0 | 0 | 1 |",
"| 4 | | 1 | 3 |",
"| 4 | 3.0 | 0 | 1 |",
"| 4 | 4.0 | 0 | 2 |",
"+---+-----+-------------+-----------------+",
"+---+-----+---------------+-----------------+",
"| a | b | __grouping_id | COUNT(1)[count] |",
"+---+-----+---------------+-----------------+",
"| | 1.0 | 2 | 2 |",
"| | 2.0 | 2 | 2 |",
"| | 3.0 | 2 | 2 |",
"| | 4.0 | 2 | 2 |",
"| 2 | | 1 | 2 |",
"| 2 | 1.0 | 0 | 2 |",
"| 3 | | 1 | 3 |",
"| 3 | 2.0 | 0 | 2 |",
"| 3 | 3.0 | 0 | 1 |",
"| 4 | | 1 | 3 |",
"| 4 | 3.0 | 0 | 1 |",
"| 4 | 4.0 | 0 | 2 |",
"+---+-----+---------------+-----------------+",
]
};
assert_batches_sorted_eq!(expected, &result);
Expand Down
Loading

0 comments on commit 2914017

Please sign in to comment.