Skip to content

Commit

Permalink
Fix sum accumulator with filtering, consolidate null handling
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 5, 2023
1 parent 68f62d1 commit ad6d4f3
Show file tree
Hide file tree
Showing 5 changed files with 380 additions and 218 deletions.
190 changes: 51 additions & 139 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
//! Defines physical expressions that can be evaluated at runtime during query execution

use arrow::array::{AsArray, PrimitiveBuilder};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use log::debug;

use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::NullState;
use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
};
Expand All @@ -51,7 +51,6 @@ use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;

use super::groups_accumulator::{accumulate_all, accumulate_all_nullable};
use super::utils::{adjust_output_array, Decimal128Averager};

/// AVG aggregate expression
Expand Down Expand Up @@ -490,10 +489,10 @@ where
/// Sums per group, stored as the native type
sums: Vec<T::Native>,

/// If we have seen a null input value for this group_index
null_inputs: BooleanBufferBuilder,
/// Track nulls in the input / filters
null_state: NullState,

/// Function that computes the average (value / count)
/// Function that computes the final average (value / count)
avg_fn: F,
}

Expand All @@ -513,137 +512,10 @@ where
sum_data_type: sum_data_type.clone(),
counts: vec![],
sums: vec![],
null_inputs: BooleanBufferBuilder::new(0),
null_state: NullState::new(),
avg_fn,
}
}

/// Bumps the per-group counter in `self.counts` by one for each input row
/// handed to the callback by the accumulate helpers.
///
/// `self.counts` is first resized to `total_num_groups`, with any new
/// entries starting at zero. `group_indices`, `values` and `opt_filter`
/// are forwarded unchanged to `accumulate_all` (when `values` holds no
/// nulls) or `accumulate_all_nullable` (otherwise); in the nullable case
/// only rows reported as valid are counted.
fn increment_counts(
    &mut self,
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&arrow_array::BooleanArray>,
    total_num_groups: usize,
) {
    self.counts.resize(total_num_groups, 0);
    let counts = &mut self.counts;

    if values.null_count() > 0 {
        // Some inputs are null: count a row only when it is valid
        accumulate_all_nullable(
            group_indices,
            values,
            opt_filter,
            |group_index, _value, is_valid| {
                if is_valid {
                    counts[group_index] += 1;
                }
            },
        )
    } else {
        // No nulls at all: every row passed to the callback counts
        accumulate_all(group_indices, values, opt_filter, |group_index, _value| {
            counts[group_index] += 1;
        })
    }
}

/// Merges already-aggregated counts into `self.counts`.
///
/// `self.counts` is first resized to `total_num_groups`, with any new
/// entries starting at zero. Each entry of `partial_counts` handed to the
/// callback is then added to the counter of its group. `group_indices`
/// and `opt_filter` are forwarded unchanged to `accumulate_all` (when
/// `partial_counts` holds no nulls) or `accumulate_all_nullable`
/// (otherwise); in the nullable case only valid entries are added.
fn update_counts_with_partial_counts(
    &mut self,
    group_indices: &[usize],
    partial_counts: &UInt64Array,
    opt_filter: Option<&arrow_array::BooleanArray>,
    total_num_groups: usize,
) {
    self.counts.resize(total_num_groups, 0);
    let counts = &mut self.counts;

    if partial_counts.null_count() > 0 {
        // Some partial counts are null: merge only the valid ones
        accumulate_all_nullable(
            group_indices,
            partial_counts,
            opt_filter,
            |group_index, partial_count, is_valid| {
                if is_valid {
                    counts[group_index] += partial_count;
                }
            },
        )
    } else {
        // No nulls at all: merge every partial count passed to the callback
        accumulate_all(
            group_indices,
            partial_counts,
            opt_filter,
            |group_index, partial_count| {
                counts[group_index] += partial_count;
            },
        )
    }
}

/// Folds `values` into the per-group running totals held in `self.sums`.
///
/// Before accumulating:
/// * `self.sums` is grown to `total_num_groups`; new slots hold
///   `T::default_value()`.
/// * `self.null_inputs` is padded with `true` bits so that it tracks at
///   least `total_num_groups` groups.
///
/// Values are combined with `add_wrapping`. When `values` contains nulls,
/// a null row leaves the sum untouched and clears that group's bit in
/// `self.null_inputs`.
///
/// NOTE(review): as written, one null input is enough to flag the whole
/// group as null -- confirm this is the intended behavior for AVG.
fn update_sums(
    &mut self,
    group_indices: &[usize],
    values: &PrimitiveArray<T>,
    opt_filter: Option<&arrow_array::BooleanArray>,
    total_num_groups: usize,
) {
    self.sums
        .resize_with(total_num_groups, || T::default_value());

    // Groups not tracked yet start out valid; their bit is only cleared
    // later if a null input shows up for them.
    let tracked_groups = self.null_inputs.len();
    if tracked_groups < total_num_groups {
        self.null_inputs
            .append_n(total_num_groups - tracked_groups, true);
    }

    let sums = &mut self.sums;
    if values.null_count() > 0 {
        let null_inputs = &mut self.null_inputs;
        accumulate_all_nullable(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value, is_valid| {
                if !is_valid {
                    // a null input flags this group as null
                    null_inputs.set_bit(group_index, false);
                    return;
                }
                sums[group_index] = sums[group_index].add_wrapping(new_value);
            },
        )
    } else {
        accumulate_all(
            group_indices,
            values,
            opt_filter,
            |group_index, new_value| {
                // add_wrapping cannot fail, so the value is added without
                // first checking whether the slot was flagged as null
                sums[group_index] = sums[group_index].add_wrapping(new_value);
            },
        )
    }
}

/// Turns the per-group bits collected in `self.null_inputs` into a
/// `NullBuffer`.
///
/// Returns `Some(buffer)` when at least one group is flagged as null and
/// `None` when every group is valid. Calling `finish()` consumes the
/// builder's contents, so `self.null_inputs` is reset by this call.
fn build_nulls(&mut self) -> Option<NullBuffer> {
    let null_buffer = NullBuffer::new(self.null_inputs.finish());
    match null_buffer.null_count() {
        0 => None,
        _ => Some(null_buffer),
    }
}
}

impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
Expand All @@ -661,8 +533,30 @@ where
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values.get(0).unwrap().as_primitive::<T>();

self.increment_counts(group_indices, values, opt_filter, total_num_groups);
self.update_sums(group_indices, values, opt_filter, total_num_groups);
// increment counts
self.counts.resize(total_num_groups, 0);
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, _new_value| {
self.counts[group_index] += 1;
},
);

// update sums
self.sums.resize(total_num_groups, T::default_value());
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);
},
);

Ok(())
}
Expand All @@ -678,21 +572,39 @@ where
// first batch is counts, second is partial sums
let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
let partial_sums = values.get(1).unwrap().as_primitive::<T>();
self.update_counts_with_partial_counts(
// update counts with partial counts
self.counts.resize(total_num_groups, 0);
self.null_state.accumulate(
group_indices,
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
self.counts[group_index] += partial_count;
},
);

// update sums
self.sums
.resize_with(total_num_groups, || T::default_value());
self.null_state.accumulate(
group_indices,
partial_sums,
opt_filter,
total_num_groups,
|group_index, new_value| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);
},
);
self.update_sums(group_indices, partial_sums, opt_filter, total_num_groups);

Ok(())
}

fn evaluate(&mut self) -> Result<ArrayRef> {
let counts = std::mem::take(&mut self.counts);
let sums = std::mem::take(&mut self.sums);
let nulls = self.build_nulls();
let nulls = self.null_state.build();

assert_eq!(counts.len(), sums.len());

Expand Down Expand Up @@ -727,7 +639,7 @@ where

// return arrays for sums and counts
fn state(&mut self) -> Result<Vec<ArrayRef>> {
let nulls = self.build_nulls();
let nulls = self.null_state.build();
let counts = std::mem::take(&mut self.counts);
let counts = UInt64Array::from(counts); // zero copy

Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ use datafusion_row::accessor::RowAccessor;
use crate::expressions::format_state_name;

use super::groups_accumulator::accumulate::{
accumulate_all, accumulate_indices, accumulate_indices_nullable,
accumulate_all, accumulate_all_nullable, accumulate_indices,
accumulate_indices_nullable,
};
use super::groups_accumulator::accumulate_all_nullable;

/// COUNT aggregate expression
/// Returns the amount of non-null values of the given expression.
Expand Down
Loading

0 comments on commit ad6d4f3

Please sign in to comment.