Skip to content

Commit

Permalink
Min/Max for primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniël Heres committed Jul 5, 2023
1 parent 24abb14 commit 6e740a4
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,29 @@ mod test {

let null_buffer = null_state.build();

if null_buffer != expected_null_buffer {
if let (Some(null_buffer), Some(expected_null_buffer)) =
(null_buffer.as_ref(), expected_null_buffer.as_ref())
{
null_buffer
.iter()
.zip(expected_null_buffer.iter())
.enumerate()
.for_each(|(i, (valid, expected_valid))| {
println!(
"nulls[{i}]: valid: {valid}, expected: {expected_valid}"
);
println!(
" expected_seen_values: {} expected_null_input: {}",
expected_seen_values.contains(&i),
expected_null_input.contains(&i)
);

assert_eq!(valid, expected_valid, "Index {i}");
})
};
}

assert_eq!(null_buffer, expected_null_buffer);
}

Expand Down
252 changes: 251 additions & 1 deletion datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

use crate::{AggregateExpr, PhysicalExpr};
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
use arrow::compute;
use arrow::datatypes::{DataType, TimeUnit};
use arrow::{
Expand All @@ -35,9 +35,15 @@ use arrow::{
},
datatypes::Field,
};
use arrow_array::cast::AsArray;
use arrow_array::types::{
ArrowPrimitiveType, Decimal128Type, Float32Type, Float64Type, UInt32Type, UInt64Type,
};
use arrow_array::{ArrowNumericType, PrimitiveArray};
use datafusion_common::ScalarValue;
use datafusion_common::{downcast_value, DataFusionError, Result};
use datafusion_expr::Accumulator;
use log::debug;

use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
Expand All @@ -48,7 +54,9 @@ use arrow::array::Array;
use arrow::array::Decimal128Array;
use datafusion_row::accessor::RowAccessor;

use super::groups_accumulator::accumulate::NullState;
use super::moving_min_max;
use super::utils::adjust_output_array;

// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
Expand Down Expand Up @@ -125,6 +133,10 @@ impl AggregateExpr for Max {
is_row_accumulator_support_dtype(&self.data_type)
}

fn groups_accumulator_supported(&self) -> bool {
self.data_type.is_primitive()
}

fn create_row_accumulator(
&self,
start_index: usize,
Expand All @@ -135,6 +147,47 @@ impl AggregateExpr for Max {
)))
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
match self.data_type {
DataType::UInt32 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
UInt32Type,
false,
>::new(
&self.data_type, &self.data_type
))),
DataType::UInt64 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
UInt64Type,
false,
>::new(
&self.data_type, &self.data_type
))),
DataType::Float32 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Float32Type,
false,
>::new(
&self.data_type, &self.data_type
))),
DataType::Float64 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Float64Type,
false,
>::new(
&self.data_type, &self.data_type
))),
DataType::Decimal128(_, _) => {
Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Decimal128Type,
false,
>::new(
&self.data_type, &self.data_type
)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"MinMaxGroupsPrimitiveAccumulator not supported for {}",
self.data_type
))),
}
}

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
Some(Arc::new(self.clone()))
}
Expand Down Expand Up @@ -835,6 +888,55 @@ impl AggregateExpr for Min {
)))
}

fn groups_accumulator_supported(&self) -> bool {
Max::groups_accumulator_supported(&Max::new(
self.expr.clone(),
self.name.clone(),
self.data_type.clone(),
))
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
match self.data_type {
DataType::UInt32 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
UInt32Type,
true,
>::new(
&self.data_type, &self.data_type
))),
DataType::UInt64 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
UInt64Type,
true,
>::new(
&self.data_type, &self.data_type
))),
DataType::Float32 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Float32Type,
true,
>::new(
&self.data_type, &self.data_type
))),
DataType::Float64 => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Float64Type,
true,
>::new(
&self.data_type, &self.data_type
))),
DataType::Decimal128(_, _) => {
Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
Decimal128Type,
true,
>::new(
&self.data_type, &self.data_type
)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"MinMaxGroupsPrimitiveAccumulator not supported for {}",
self.data_type
))),
}
}

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
Some(Arc::new(self.clone()))
}
Expand Down Expand Up @@ -1022,6 +1124,154 @@ impl RowAccumulator for MinRowAccumulator {
}
}

/// An accumulator to compute the min or max of PrimitiveArray<T>.
/// Stores values as native types, and does overflow checking
#[derive(Debug)]
struct MinMaxGroupsPrimitiveAccumulator<T, const MIN: bool>
where
T: ArrowNumericType + Send,
{
/// The type of the computed sum
min_max_data_type: DataType,

/// The type of the returned sum
return_data_type: DataType,

/// Min/max per group, stored as the native type
min_max: Vec<T::Native>,

/// Track nulls in the input / filters
null_state: NullState,
}

impl<T, const MIN: bool> MinMaxGroupsPrimitiveAccumulator<T, MIN>
where
T: ArrowNumericType + Send,
{
pub fn new(min_max_data_type: &DataType, return_data_type: &DataType) -> Self {
debug!(
"MinMaxGroupsPrimitiveAccumulator ({}, sum type: {min_max_data_type:?}) --> {return_data_type:?}",
std::any::type_name::<T>()
);

Self {
return_data_type: return_data_type.clone(),
min_max_data_type: min_max_data_type.clone(),
min_max: vec![],
null_state: NullState::new(),
}
}
}

impl<T, const MIN: bool> GroupsAccumulator for MinMaxGroupsPrimitiveAccumulator<T, MIN>
where
T: ArrowNumericType + Send,
{
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values.get(0).unwrap().as_primitive::<T>();

// update sums
self.min_max
.resize_with(total_num_groups, || T::default_value());

// NullState dispatches / handles tracking nulls and groups that saw no values
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let val: &mut <T as ArrowPrimitiveType>::Native =
&mut self.min_max[group_index];
if MIN {
if new_value < *val {
*val = new_value;
}
} else {
if new_value > *val {
*val = new_value;
}
}
},
);

Ok(())
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "one argument to merge_batch");
// first batch is partial sums
let partial_min_max: &PrimitiveArray<T> =
values.get(0).unwrap().as_primitive::<T>();

// Sum partial sums
self.min_max
.resize_with(total_num_groups, || T::default_value());

self.null_state.accumulate(
group_indices,
partial_min_max,
opt_filter,
total_num_groups,
|group_index, new_value| {
let val = &mut self.min_max[group_index];
// TODO: support min and max
if MIN {
if new_value < *val {
*val = new_value;
}
} else {
if new_value > *val {
*val = new_value;
}
}
},
);

Ok(())
}

fn evaluate(&mut self) -> Result<ArrayRef> {
let min_max = std::mem::take(&mut self.min_max);
let nulls = self.null_state.build();

let min_max = PrimitiveArray::<T>::new(min_max.into(), nulls); // no copy
let min_max = adjust_output_array(&self.return_data_type, Arc::new(min_max))?;

Ok(Arc::new(min_max))
}

// return arrays for sums and counts
fn state(&mut self) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build();

let min_max = std::mem::take(&mut self.min_max);
let min_max = Arc::new(PrimitiveArray::<T>::new(min_max.into(), nulls.clone())); // zero copy

let sums = adjust_output_array(&self.min_max_data_type, min_max)?;

// TODO: Sum expects sum/count array, but count is not needed
Ok(vec![sums.clone() as ArrayRef])
}

fn size(&self) -> usize {
self.min_max.capacity() * std::mem::size_of::<usize>()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 2 additions & 1 deletion datafusion/physical-expr/src/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ impl RowAccumulator for SumRowAccumulator {
}
}

/// An accumulator to compute the sum of values in [`PrimitiveArray<T>`]
/// An accumulator to compute the sum of PrimitiveArray<T>.
/// Stores values as native types, and does overflow checking
#[derive(Debug)]
struct SumGroupsAccumulator<T>
where
Expand Down

0 comments on commit 6e740a4

Please sign in to comment.