Skip to content

Commit

Permalink
Reduce code duplication in PrimitiveGroupValueBuilder with const ge…
Browse files Browse the repository at this point in the history
…nerics
  • Loading branch information
alamb committed Oct 1, 2024
1 parent 84ac4f9 commit 4be646a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
// under the License.

use crate::aggregates::group_values::group_column::{
ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder,
PrimitiveGroupValueBuilder,
ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
};
use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
Expand Down Expand Up @@ -135,10 +134,10 @@ impl GroupValuesColumn {
macro_rules! instantiate_primitive {
($v:expr, $nullable:expr, $t:ty) => {
if $nullable {
let b = PrimitiveGroupValueBuilder::<$t>::new();
let b = PrimitiveGroupValueBuilder::<$t, true>::new();
$v.push(Box::new(b) as _)
} else {
let b = NonNullPrimitiveGroupValueBuilder::<$t>::new();
let b = PrimitiveGroupValueBuilder::<$t, false>::new();
$v.push(Box::new(b) as _)
}
};
Expand Down
104 changes: 36 additions & 68 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,75 +63,25 @@ pub trait GroupColumn: Send + Sync {
fn take_n(&mut self, n: usize) -> ArrayRef;
}

/// An implementation of [`GroupColumn`] for primitive values which are known to have no nulls
#[derive(Debug)]
pub struct NonNullPrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
group_values: Vec<T::Native>,
}

impl<T> NonNullPrimitiveGroupValueBuilder<T>
where
T: ArrowPrimitiveType,
{
pub fn new() -> Self {
Self {
group_values: vec![],
}
}
}

impl<T: ArrowPrimitiveType> GroupColumn for NonNullPrimitiveGroupValueBuilder<T> {
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
// know input has no nulls
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
// input can't possibly have nulls, so don't worry about them
self.group_values.push(array.as_primitive::<T>().value(row))
}

fn len(&self) -> usize {
self.group_values.len()
}

fn size(&self) -> usize {
self.group_values.allocated_size()
}

fn build(self: Box<Self>) -> ArrayRef {
let Self { group_values } = *self;

let nulls = None;

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(group_values),
nulls,
))
}

fn take_n(&mut self, n: usize) -> ArrayRef {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = None;

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
first_n_nulls,
))
}
}

/// An implementation of [`GroupColumn`] for primitive values which may have nulls
/// An implementation of [`GroupColumn`] for primitive values
///
/// Optimized to skip null buffer construction if the input is known to be non nullable
///
/// # Template parameters
///
/// `T`: the native Rust type that stores the data
/// `NULLABLE`: if the data can contain any nulls
#[derive(Debug)]
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> {
group_values: Vec<T::Native>,
nulls: MaybeNullBufferBuilder,
}

impl<T> PrimitiveGroupValueBuilder<T>
impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE>
where
T: ArrowPrimitiveType,
{
/// Create a new `PrimitiveGroupValueBuilder`
pub fn new() -> Self {
Self {
group_values: vec![],
Expand All @@ -140,18 +90,32 @@ where
}
}

impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
for PrimitiveGroupValueBuilder<T, NULLABLE>
{
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
// Perf: skip null check (by short circuit) if input is not ullable
let null_match = if NULLABLE {
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
} else {
true
};

null_match
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
if array.is_null(row) {
self.nulls.append(true);
self.group_values.push(T::default_value());
// Perf: skip null check if input can't have nulls
if NULLABLE {
if array.is_null(row) {
self.nulls.append(true);
self.group_values.push(T::default_value());
} else {
self.nulls.append(false);
self.group_values.push(array.as_primitive::<T>().value(row));
}
} else {
self.nulls.append(false);
self.group_values.push(array.as_primitive::<T>().value(row));
}
}
Expand All @@ -171,6 +135,9 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
} = *self;

let nulls = nulls.build();
if !NULLABLE {
assert!(nulls.is_none(), "unexpected nulls in non nullable input");
}

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(group_values),
Expand All @@ -180,7 +147,8 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {

fn take_n(&mut self, n: usize) -> ArrayRef {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = self.nulls.take_n(n);

let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None };

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
Expand Down

0 comments on commit 4be646a

Please sign in to comment.