Skip to content

Commit

Permalink
Support convert_to_state for AVG accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 30, 2024
1 parent 0d994a6 commit ee5ac1c
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

use arrow::array::{
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
AsArray, BooleanArray, Int64Array, PrimitiveArray, PrimitiveBuilder, UInt64Array,
};
use arrow::buffer::NullBuffer;
use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Expand Down Expand Up @@ -554,8 +555,51 @@ where
Ok(())
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let counts = Arc::new(Int64Array::from_value(1, values.len()));
let sums = values[0].as_primitive::<T>();

let nulls = filtered_null_mask(opt_filter, sums);
let sums = PrimitiveArray::<T>::new(sums.values().clone(), nulls)
.with_data_type(self.sum_data_type.clone());

Ok(vec![counts, Arc::new(sums)])
}

/// Reports that this accumulator provides a `convert_to_state`
/// implementation; always returns `true`.
fn convert_to_state_supported(&self) -> bool {
true
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u64>()
+ self.sums.capacity() * std::mem::size_of::<T>()
}
}

/// Turns a `BooleanArray` filter into a validity mask.
///
/// A slot in the returned `NullBuffer` is valid only where the filter value
/// is `true` and the filter itself is non-null; slots where the filter is
/// `false` or null come back as null.
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
    // Reinterpret the filter's boolean values as validity bits:
    // `true` => valid, `false` => null.
    let selected = NullBuffer::from(filter.values().clone());
    // Intersect with the filter's own validity so null filter entries are
    // treated as "not selected".
    NullBuffer::union(Some(&selected), filter.nulls())
}

/// Compute the final null mask for an array
///
/// The output null mask :
/// * is true (non null) for all values that were true in the filter and non null in the input
/// * is false (null) for all values that were false in the filter or null in the input
fn filtered_null_mask(
opt_filter: Option<&BooleanArray>,
input: &dyn Array,
) -> Option<NullBuffer> {
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}

0 comments on commit ee5ac1c

Please sign in to comment.