diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 91d87302ce99..fc598a41d276 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -16,8 +16,7 @@ // under the License. use crate::aggregates::group_values::group_column::{ - ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder, - PrimitiveGroupValueBuilder, + ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder, }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; @@ -135,10 +134,10 @@ impl GroupValuesColumn { macro_rules! instantiate_primitive { ($v:expr, $nullable:expr, $t:ty) => { if $nullable { - let b = PrimitiveGroupValueBuilder::<$t>::new(); + let b = PrimitiveGroupValueBuilder::<$t, true>::new(); $v.push(Box::new(b) as _) } else { - let b = NonNullPrimitiveGroupValueBuilder::<$t>::new(); + let b = PrimitiveGroupValueBuilder::<$t, false>::new(); $v.push(Box::new(b) as _) } }; diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 7409f5c214b9..27e027f85155 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -63,75 +63,25 @@ pub trait GroupColumn: Send + Sync { fn take_n(&mut self, n: usize) -> ArrayRef; } -/// An implementation of [`GroupColumn`] for primitive values which are known to have no nulls -#[derive(Debug)] -pub struct NonNullPrimitiveGroupValueBuilder { - group_values: Vec, -} - -impl NonNullPrimitiveGroupValueBuilder -where - T: ArrowPrimitiveType, -{ - pub fn new() -> Self { - Self { - group_values: vec![], - } - } -} - -impl GroupColumn for NonNullPrimitiveGroupValueBuilder { - fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> 
bool { - // know input has no nulls - self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) - } - - fn append_val(&mut self, array: &ArrayRef, row: usize) { - // input can't possibly have nulls, so don't worry about them - self.group_values.push(array.as_primitive::().value(row)) - } - - fn len(&self) -> usize { - self.group_values.len() - } - - fn size(&self) -> usize { - self.group_values.allocated_size() - } - - fn build(self: Box) -> ArrayRef { - let Self { group_values } = *self; - - let nulls = None; - - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(group_values), - nulls, - )) - } - - fn take_n(&mut self, n: usize) -> ArrayRef { - let first_n = self.group_values.drain(0..n).collect::>(); - let first_n_nulls = None; - - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(first_n), - first_n_nulls, - )) - } -} - -/// An implementation of [`GroupColumn`] for primitive values which may have nulls +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls #[derive(Debug)] -pub struct PrimitiveGroupValueBuilder { +pub struct PrimitiveGroupValueBuilder { group_values: Vec, nulls: MaybeNullBufferBuilder, } -impl PrimitiveGroupValueBuilder +impl PrimitiveGroupValueBuilder where T: ArrowPrimitiveType, { + /// Create a new `PrimitiveGroupValueBuilder` pub fn new() -> Self { Self { group_values: vec![], @@ -140,18 +90,32 @@ where } } -impl GroupColumn for PrimitiveGroupValueBuilder { +impl GroupColumn + for PrimitiveGroupValueBuilder +{ fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - self.nulls.is_null(lhs_row) == array.is_null(rhs_row) + // Perf: skip null check (by short circuit) if input is not nullable + let null_match = if NULLABLE { + self.nulls.is_null(lhs_row) == 
array.is_null(rhs_row) + } else { + true + }; + + null_match && self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } fn append_val(&mut self, array: &ArrayRef, row: usize) { - if array.is_null(row) { - self.nulls.append(true); - self.group_values.push(T::default_value()); + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::().value(row)); + } } else { - self.nulls.append(false); self.group_values.push(array.as_primitive::().value(row)); } } @@ -171,6 +135,9 @@ impl GroupColumn for PrimitiveGroupValueBuilder { } = *self; let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); + } Arc::new(PrimitiveArray::::new( ScalarBuffer::from(group_values), @@ -180,7 +147,8 @@ impl GroupColumn for PrimitiveGroupValueBuilder { fn take_n(&mut self, n: usize) -> ArrayRef { let first_n = self.group_values.drain(0..n).collect::>(); - let first_n_nulls = self.nulls.take_n(n); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; Arc::new(PrimitiveArray::::new( ScalarBuffer::from(first_n),