Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GROUPING aggregate function (following Postgres behavior.) #12565

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions datafusion/expr-common/src/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ pub trait GroupsAccumulator: Send {
false
}

/// Update this accumulator's groupings. Used for aggregates that
/// report data about the grouping strategy e.g. GROUPING.
///
/// * `group_indices`: Indices of groups in the current grouping set
///
/// * `group_mask`: Mask for the current grouping set (true means null/aggregated)
///
/// * `total_num_groups`: the number of groups (the largest
/// group_index is thus `total_num_groups - 1`).
fn update_groupings(
&mut self,
_group_indices: &[usize],
_group_mask: &[bool],
_total_num_groups: usize,
) -> Result<()> {
Ok(())
}

/// Amount of memory used to store the state of this accumulator,
/// in bytes.
///
Expand Down
199 changes: 192 additions & 7 deletions datafusion/functions-aggregate/src/grouping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,34 @@

use std::any::Any;
use std::fmt;
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::array::AsArray;
use arrow::array::BooleanArray;
use arrow::array::UInt32Array;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::UInt32Type;
use datafusion_common::internal_datafusion_err;
use datafusion_common::internal_err;
use datafusion_common::plan_err;
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::EmitTo;
use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

make_udaf_expr_and_func!(
Grouping,
grouping,
expression,
"Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.",
"Returns a bitmap where bit i is 1 if this row is aggregated across the ith argument to GROUPING and 0 otherwise.",
grouping_udaf
);

Expand All @@ -59,9 +73,55 @@ impl Grouping {
/// Create a new GROUPING aggregate function.
pub fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
signature: Signature::variadic_any(Volatility::Immutable),
}
}

/// Create an accumulator for GROUPING(grouping_args) in a GROUP BY over group_exprs
/// A special creation function is necessary because GROUPING has unusual input requirements.
pub fn create_grouping_accumulator(
&self,
grouping_args: &[Arc<dyn PhysicalExpr>],
group_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Result<Box<dyn GroupsAccumulator>> {
if grouping_args.len() > 32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets have it as a const

return plan_err!(
"GROUPING is supported for up to 32 columns. Consider another \
GROUPING statement if you need to aggregate over more columns."
);
}
// The PhysicalExprs of grouping_exprs must be Column PhysicalExpr. Because if
// the group by PhysicalExpr in SQL is non-Column PhysicalExpr, then there is
// a ProjectionExec before AggregateExec to convert the non-column PhysicalExpr
// to Column PhysicalExpr.
let column_index =
|expr: &Arc<dyn PhysicalExpr>| match expr.as_any().downcast_ref::<Column>() {
Some(column) => Ok(column.index()),
None => internal_err!("Grouping doesn't support expr: {}", expr),
};
Comment on lines +93 to +101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only true when one enabled the optimizer rule CommonSubexprEliminate . Does not seems like a acceptable to depend on optimizer rules for correctness/basic support.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we look for equal PhysicalExprs?

The Postgres docs imply they do ~text comparison but I'm not sure how accessible that info is at this layer.

let group_by_columns: Result<Vec<_>> =
group_exprs.iter().map(|(e, _)| column_index(e)).collect();
let group_by_columns = group_by_columns?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be 1 liner?


let arg_columns: Result<Vec<_>> =
grouping_args.iter().map(column_index).collect();
let expr_indices: Result<Vec<_>> = arg_columns?
.iter()
.map(|arg| {
group_by_columns
.iter()
.position(|gb| arg == gb)
.ok_or_else(|| {
internal_datafusion_err!("Invalid grouping set indices.")
})
})
.collect();

Ok(Box::new(GroupingAccumulator {
grouping_ids: vec![],
expr_indices: expr_indices?,
}))
}
}

impl AggregateUDFImpl for Grouping {
Expand All @@ -78,20 +138,145 @@ impl AggregateUDFImpl for Grouping {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
Ok(DataType::UInt32)
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new(
format_state_name(args.name, "grouping"),
DataType::Int32,
DataType::UInt32,
true,
)])
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
not_impl_err!(
"physical plan is not yet implemented for GROUPING aggregate function"
)
not_impl_err!("The GROUPING function requires a GROUP BY context.")
}

fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
false
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// Use `create_grouping_accumulator` instead.
not_impl_err!("GROUPING is not supported when invoked this way.")
}
}

struct GroupingAccumulator {
// Grouping ID value for each group
grouping_ids: Vec<u32>,
// Indices of GROUPING arguments as they appear in the GROUPING SET
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have more details or example on indices?

expr_indices: Vec<usize>,
}

impl GroupingAccumulator {
fn mask_to_id(&self, mask: &[bool]) -> Result<u32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add more description on this method, how it changes the mask

let mut id: u32 = 0;
// rightmost entry is the LSB
for (i, &idx) in self.expr_indices.iter().rev().enumerate() {
match mask.get(idx) {
Some(true) => id |= 1 << i,
Some(false) => {}
None => {
return internal_err!(
"Index out of bounds while calculating GROUPING id."
)
}
}
}
Ok(id)
}
}

impl GroupsAccumulator for GroupingAccumulator {
fn update_batch(
&mut self,
_values: &[ArrayRef],
_group_indices: &[usize],
_opt_filter: Option<&BooleanArray>,
_total_num_groups: usize,
) -> Result<()> {
// No-op since GROUPING doesn't care about values
Ok(())
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
_opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to merge_batch");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we always expect only 1 array ?

self.grouping_ids.resize(total_num_groups, 0);
let other_ids = values[0].as_primitive::<UInt32Type>();
accumulate(group_indices, other_ids, None, |group_index, group_id| {
self.grouping_ids[group_index] |= group_id;
});
Ok(())
}

fn update_groupings(
&mut self,
group_indices: &[usize],
group_mask: &[bool],
total_num_groups: usize,
) -> Result<()> {
self.grouping_ids.resize(total_num_groups, 0);
let group_id = self.mask_to_id(group_mask)?;
for &group_idx in group_indices {
self.grouping_ids[group_idx] = group_id;
}
Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let values = emit_to.take_needed(&mut self.grouping_ids);
let values = UInt32Array::new(values.into(), None);
Ok(Arc::new(values))
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
self.evaluate(emit_to).map(|arr| vec![arr])
}

fn size(&self) -> usize {
self.grouping_ids.capacity() * std::mem::size_of::<u32>()
}
}

#[cfg(test)]
mod tests {
use crate::grouping::GroupingAccumulator;

#[test]
fn test_group_ids() {
let grouping = GroupingAccumulator {
grouping_ids: vec![],
expr_indices: vec![0, 1, 3, 2],
};
let cases = vec![
(0b0000, vec![false, false, false, false]),
(0b1000, vec![true, false, false, false]),
(0b0100, vec![false, true, false, false]),
(0b1010, vec![true, false, false, true]),
(0b1001, vec![true, false, true, false]),
];
for (expected, input) in cases {
assert_eq!(expected, grouping.mask_to_id(&input).unwrap());
}
}
#[test]
fn test_bad_index() {
let grouping = GroupingAccumulator {
grouping_ids: vec![],
expr_indices: vec![5],
};
let res = grouping.mask_to_id(&[false]);
assert!(res.is_err())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may want to check the error message as well

}
}
19 changes: 15 additions & 4 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ impl PartialEq for PhysicalGroupBy {
}
}

pub(crate) struct PhysicalGroupingSet {
/// Exprs/columns over which the grouping set is aggregated
values: Vec<ArrayRef>,
/// True if the corresponding value is null in this grouping set
mask: Vec<bool>,
}

enum StreamType {
AggregateStream(AggregateStream),
GroupedHash(GroupedHashAggregateStream),
Expand Down Expand Up @@ -1140,13 +1147,13 @@ fn evaluate_optional(
/// - `batch`: the `RecordBatch` to evaluate against
///
/// Returns: A Vec of Vecs of Array of results
/// The outer Vec appears to be for grouping sets
/// The outer Vec contains the grouping sets defined by `group_by.groups`
/// The inner Vec contains the results per expression
/// The inner-inner Array contains the results per row
pub(crate) fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
) -> Result<Vec<PhysicalGroupingSet>> {
let exprs: Vec<ArrayRef> = group_by
.expr
.iter()
Expand All @@ -1169,7 +1176,7 @@ pub(crate) fn evaluate_group_by(
.groups
.iter()
.map(|group| {
group
let v = group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets have more meaningful name?

.iter()
.enumerate()
.map(|(idx, is_null)| {
Expand All @@ -1179,7 +1186,11 @@ pub(crate) fn evaluate_group_by(
Arc::clone(&exprs[idx])
}
})
.collect()
.collect();
PhysicalGroupingSet {
values: v,
mask: group.clone(),
}
})
.collect())
}
Expand Down
Loading