diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18642fb84329..3ca462001488 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,8 +19,9 @@ use arrow::array::{ self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + AsArray, BooleanArray, Int64Array, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, @@ -554,8 +555,51 @@ where Ok(()) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let counts = Arc::new(Int64Array::from_value(1, values.len())); + let sums = values[0].as_primitive::(); + + let nulls = filtered_null_mask(opt_filter, sums); + let sums = PrimitiveArray::::new(sums.values().clone(), nulls) + .with_data_type(self.sum_data_type.clone()); + + Ok(vec![counts, Arc::new(sums)]) + } + + fn convert_to_state_supported(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() } } + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer` +/// where the NullBuffer is true for all values that were true +/// in the filter and `null` for any values that were false or null +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + // Only keep values where the filter was true + // convert all false to null + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute the final null mask for an array +/// +/// The output null mask : +/// * is true (non null) for all values that were true in the filter and non null in the input +/// * is false (null) for all values that were false in the filter or null in the input +fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +}