Skip to content

Commit

Permalink
Merge branch 'alamb/hash_agg_spike' of github.com:alamb/arrow-datafus…
Browse files Browse the repository at this point in the history
…ion into alamb/hash_agg_spike
  • Loading branch information
alamb committed Jul 3, 2023
2 parents 6cab205 + 6275a9f commit 587dc0e
Show file tree
Hide file tree
Showing 4 changed files with 405 additions and 15 deletions.
9 changes: 9 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,12 @@ opt-level = 3
overflow-checks = false
panic = 'unwind'
rpath = false

# TODO remove after 43 release
[patch.crates-io]
arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" }
22 changes: 9 additions & 13 deletions datafusion/core/src/physical_plan/aggregates/row_hash2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::task::{Context, Poll};
use std::vec;

use ahash::RandomState;
use arrow::row::{OwnedRow, RowConverter, SortField};
use arrow::row::{RowConverter, SortField, Rows};
use datafusion_physical_expr::hash_utils::create_hashes;
use futures::ready;
use futures::stream::{Stream, StreamExt};
Expand Down Expand Up @@ -165,11 +165,7 @@ pub(crate) struct GroupedHashAggregateStream2 {
///
/// The row format is used to compare group keys quickly. This is
/// especially important for multi-column group keys.
///
/// TODO, make this Rows (rather than Vec<OwnedRow> to reduce
/// allocations once
/// https://github.com/apache/arrow-rs/issues/4466 is available
group_values: Vec<OwnedRow>,
group_values: Rows,

/// scratch space for the current input Batch being
/// processed. Reused across batches here to avoid reallocations
Expand Down Expand Up @@ -241,7 +237,7 @@ impl GroupedHashAggregateStream2 {
let name = format!("GroupedHashAggregateStream2[{partition}]");
let reservation = MemoryConsumer::new(name).register(context.memory_pool());
let map = RawTable::with_capacity(0);
let group_by_values = vec![];
let group_by_values = row_converter.empty_rows(0, 0);
let current_group_indices = vec![];

timer.done();
Expand Down Expand Up @@ -398,7 +394,7 @@ impl GroupedHashAggregateStream2 {
// TODO update *allocated based on size of the row
// that was just pushed into
// aggr_state.group_by_values
group_rows.row(row) == self.group_values[*group_idx].row()
group_rows.row(row) == self.group_values.row(*group_idx)
});

let group_idx = match entry {
Expand All @@ -407,8 +403,8 @@ impl GroupedHashAggregateStream2 {
// 1.2 Need to create new entry for the group
None => {
// Add new entry to aggr_state and save newly created index
let group_idx = self.group_values.len();
self.group_values.push(group_rows.row(row).owned());
let group_idx = self.group_values.num_rows();
self.group_values.push(group_rows.row(row));

// for hasher function, use precomputed hash value
self.map.insert_accounted(
Expand Down Expand Up @@ -455,7 +451,7 @@ impl GroupedHashAggregateStream2 {
.zip(input_values.iter())
.zip(filter_values.iter());

let total_num_groups = self.group_values.len();
let total_num_groups = self.group_values.num_rows();

for ((acc, values), opt_filter) in t {
let acc_size_pre = acc.size();
Expand Down Expand Up @@ -499,13 +495,13 @@ impl GroupedHashAggregateStream2 {
impl GroupedHashAggregateStream2 {
/// Create an output RecordBatch with all group keys and accumulator states/values
fn create_batch_from_map(&mut self) -> Result<RecordBatch> {
if self.group_values.is_empty() {
if self.group_values.num_rows() == 0 {
let schema = self.schema.clone();
return Ok(RecordBatch::new_empty(schema));
}

// First output rows are the groups
let groups_rows = self.group_values.iter().map(|owned_row| owned_row.row());
let groups_rows = self.group_values.iter().map(|owned_row| owned_row);

let mut output: Vec<ArrayRef> = self.row_converter.convert_rows(groups_rows)?;

Expand Down
191 changes: 190 additions & 1 deletion datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@

use std::any::Any;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::BitAnd;
use std::sync::Arc;

use crate::aggregate::row_accumulator::RowAccumulator;
use crate::aggregate::utils::down_cast_any_ref;
use crate::{AggregateExpr, PhysicalExpr};
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
use arrow::array::{Array, Int64Array};
use arrow::compute;
use arrow::compute::kernels::cast;
use arrow::datatypes::DataType;
use arrow::{array::ArrayRef, datatypes::Field};
use arrow_array::cast::AsArray;
use arrow_array::types::{Int32Type, Int64Type, UInt32Type, UInt64Type};
use arrow_array::{ArrowNumericType, PrimitiveArray, UInt64Array};
use arrow_buffer::BooleanBuffer;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
Expand All @@ -37,6 +42,8 @@ use datafusion_row::accessor::RowAccessor;

use crate::expressions::format_state_name;

use super::groups_accumulator::accumulate::{accumulate_all, accumulate_all_nullable};

/// COUNT aggregate expression
/// Returns the amount of non-null values of the given expression.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -76,6 +83,165 @@ impl Count {
}
}

/// An accumulator that keeps a running COUNT for every group.
///
/// Counts are stored as native `u64` values in a `Vec`, indexed by group
/// index, so they can be turned into a `UInt64Array` when results are
/// produced.
///
/// T: Arrow numeric type that the input array is downcast to
/// (`PrimitiveArray<T>`) when a batch of values is counted
#[derive(Debug)]
struct CountGroupsAccumulator<T>
where
T: ArrowNumericType + Send,
{
/// The type of the returned count
return_data_type: DataType,

/// Count per group (use u64 to make UInt64Array)
counts: Vec<u64>,
// Zero-sized marker that ties the otherwise unused type parameter `T`
// to the struct
phantom: PhantomData<T>,
}

impl<T> CountGroupsAccumulator<T>
where
    T: ArrowNumericType + Send,
{
    /// Builds an accumulator that tracks no groups yet.
    ///
    /// `return_data_type` is cloned and kept so the final counts can be
    /// produced in that type.
    pub fn new(return_data_type: &DataType) -> Self {
        let return_data_type = return_data_type.clone();
        Self {
            return_data_type,
            counts: Vec::new(),
            phantom: PhantomData,
        }
    }

    /// Bumps the counter of a group by one for each value reported by the
    /// accumulate helpers.
    ///
    /// The counter vector is first resized to `total_num_groups` (new slots
    /// start at zero). When `values` contains nulls, only entries flagged as
    /// valid are counted.
    fn increment_counts(
        &mut self,
        group_indices: &[usize],
        values: &PrimitiveArray<T>,
        opt_filter: Option<&arrow_array::BooleanArray>,
        total_num_groups: usize,
    ) {
        self.counts.resize(total_num_groups, 0);
        let counts = &mut self.counts;

        if values.null_count() > 0 {
            // Null-aware path: the helper reports validity alongside each value.
            accumulate_all_nullable(
                group_indices,
                values,
                opt_filter,
                |idx, _value, valid| {
                    if valid {
                        counts[idx] += 1;
                    }
                },
            )
        } else {
            // No nulls present: every reported value counts.
            accumulate_all(group_indices, values, opt_filter, |idx, _value| {
                counts[idx] += 1;
            })
        }
    }

    /// Folds previously computed partial counts into the per-group totals.
    ///
    /// The counter vector is first resized to `total_num_groups` (new slots
    /// start at zero). When `partial_counts` contains nulls, only entries
    /// flagged as valid are added.
    fn update_counts_with_partial_counts(
        &mut self,
        group_indices: &[usize],
        partial_counts: &UInt64Array,
        opt_filter: Option<&arrow_array::BooleanArray>,
        total_num_groups: usize,
    ) {
        self.counts.resize(total_num_groups, 0);
        let counts = &mut self.counts;

        if partial_counts.null_count() > 0 {
            // Null-aware path: skip partial counts that are not valid.
            accumulate_all_nullable(
                group_indices,
                partial_counts,
                opt_filter,
                |idx, partial, valid| {
                    if valid {
                        counts[idx] += partial;
                    }
                },
            )
        } else {
            // No nulls present: add every reported partial count.
            accumulate_all(
                group_indices,
                partial_counts,
                opt_filter,
                |idx, partial| {
                    counts[idx] += partial;
                },
            )
        }
    }
}

impl<T> GroupsAccumulator for CountGroupsAccumulator<T>
where
    T: ArrowNumericType + Send,
{
    /// Counts the raw input values of one batch.
    ///
    /// `values` must hold exactly one array; it is downcast to
    /// `PrimitiveArray<T>` and handed to `increment_counts`.
    ///
    /// # Panics
    ///
    /// Panics if `values` does not contain exactly one array.
    /// NOTE(review): `as_primitive` presumably also panics when the array is
    /// not a `PrimitiveArray<T>` — confirm against the arrow `AsArray` docs.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&arrow_array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = values[0].as_primitive::<T>();

        self.increment_counts(group_indices, values, opt_filter, total_num_groups);

        Ok(())
    }

    /// Merges intermediate state (as produced by `state`) into this
    /// accumulator.
    ///
    /// `values` must hold exactly one array: the partial counts, as a
    /// `UInt64Array`.
    ///
    /// # Panics
    ///
    /// Panics if `values` does not contain exactly one array.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&arrow_array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "one argument to merge_batch");
        // the only state column is the partial counts
        let partial_counts = values[0].as_primitive::<UInt64Type>();
        self.update_counts_with_partial_counts(
            group_indices,
            partial_counts,
            opt_filter,
            total_num_groups,
        );

        Ok(())
    }

    /// Produces the final count for every group, cast to the configured
    /// return type.
    ///
    /// The internal counters are moved out, leaving this accumulator with an
    /// empty counter vector afterwards.
    fn evaluate(&mut self) -> Result<ArrayRef> {
        let counts = std::mem::take(&mut self.counts);

        // No null buffer: every group has a count (possibly zero).
        let array = PrimitiveArray::<UInt64Type>::new(counts.into(), None);
        // TODO remove cast
        let array = cast(&array, &self.return_data_type)?;

        Ok(array)
    }

    /// Returns the intermediate state: a single `UInt64Array` holding the
    /// count of each group.
    ///
    /// The internal counters are moved out, leaving this accumulator with an
    /// empty counter vector afterwards.
    fn state(&mut self) -> Result<Vec<ArrayRef>> {
        let counts = std::mem::take(&mut self.counts);
        let counts = UInt64Array::from(counts); // zero copy
        Ok(vec![Arc::new(counts) as ArrayRef])
    }

    /// Number of bytes reserved by the per-group counter buffer.
    fn size(&self) -> usize {
        // The buffer stores `u64` elements, so measure with `u64`; `usize`
        // is narrower on 32-bit targets and would under-report the size.
        self.counts.capacity() * std::mem::size_of::<u64>()
    }
}

/// count null values for multiple columns
/// for each row if one column value is null, then null_count + 1
fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
Expand Down Expand Up @@ -147,6 +313,29 @@ impl AggregateExpr for Count {
/// Returns a freshly constructed `CountAccumulator`, boxed as a
/// `dyn Accumulator`.
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    let accumulator = CountAccumulator::new();
    Ok(Box::new(accumulator))
}

/// Creates the grouped COUNT accumulator specialised for `self.data_type`.
///
/// Supported types are `UInt64`, `Int64`, `UInt32` and `Int32`; any other
/// type yields a `DataFusionError::NotImplemented` error.
///
/// NOTE(review): the element type `T` is selected from `self.data_type`,
/// which is also passed to the accumulator as its return type, while
/// `CountGroupsAccumulator::update_batch` downcasts the *input* array to
/// `PrimitiveArray<T>` — confirm the input type always matches.
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
    let return_type = &self.data_type;

    // pick the specialization matching the data type
    let accumulator: Box<dyn GroupsAccumulator> = match return_type {
        DataType::UInt64 => {
            Box::new(CountGroupsAccumulator::<UInt64Type>::new(return_type))
        }
        DataType::Int64 => {
            Box::new(CountGroupsAccumulator::<Int64Type>::new(return_type))
        }
        DataType::UInt32 => {
            Box::new(CountGroupsAccumulator::<UInt32Type>::new(return_type))
        }
        DataType::Int32 => {
            Box::new(CountGroupsAccumulator::<Int32Type>::new(return_type))
        }
        _ => {
            return Err(DataFusionError::NotImplemented(format!(
                "CountGroupsAccumulator not supported for {}",
                self.data_type
            )));
        }
    };

    Ok(accumulator)
}
}

impl PartialEq<dyn Any> for Count {
Expand Down
Loading

0 comments on commit 587dc0e

Please sign in to comment.