From 316c78173aa3d2298422bbd6f075b0d40c82a776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 3 Jul 2023 10:09:33 +0200 Subject: [PATCH 1/5] WIP count --- .../physical-expr/src/aggregate/count.rs | 238 +++++++++++++++++- 1 file changed, 236 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 22cb2512fc42..c3ad7767b1c7 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -19,17 +19,23 @@ use std::any::Any; use std::fmt::Debug; +use std::marker::PhantomData; use std::ops::BitAnd; use std::sync::Arc; use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr, GroupsAccumulator}; use arrow::array::{Array, Int64Array}; use arrow::compute; +use arrow::compute::kernels::cast; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_buffer::BooleanBuffer; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::cast::AsArray; +use arrow_array::types::{UInt64Type, Int64Type, UInt32Type, Int32Type}; +use arrow_array::{PrimitiveArray, UInt64Array, ArrowNumericType}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -37,6 +43,8 @@ use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; +use super::groups_accumulator::accumulate::{accumulate_all, accumulate_all_nullable}; + /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug, Clone)] @@ -76,6 +84,200 @@ impl Count { } } +/// An accumulator to compute the average of PrimitiveArray. 
+/// Stores values as native types, and does overflow checking +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count +#[derive(Debug)] +struct CountGroupsAccumulator +where T: ArrowNumericType + Send, +{ + /// The type of the returned count + return_data_type: DataType, + + /// Count per group (use u64 to make UInt64Array) + counts: Vec, + + /// If we have seen a null input value for this group_index + null_inputs: BooleanBufferBuilder, + + // Bind it to struct + phantom: PhantomData +} + + +impl CountGroupsAccumulator +where T: ArrowNumericType + Send, +{ + pub fn new(return_data_type: &DataType) -> Self { + Self { + return_data_type: return_data_type.clone(), + counts: vec![], + null_inputs: BooleanBufferBuilder::new(0), + phantom: PhantomData {} + } + } + + /// Adds one to each group's counter + fn increment_counts( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if values.null_count() == 0 { + accumulate_all( + group_indices, + values, + opt_filter, + |group_index, _new_value| { + self.counts[group_index] += 1; + } + ) + }else { + accumulate_all_nullable( + group_indices, + values, + opt_filter, + |group_index, _new_value, is_valid| { + if is_valid { + self.counts[group_index] += 1; + } + }, + ) + } + } + + /// Adds the counts with the partial counts + fn update_counts_with_partial_counts( + &mut self, + group_indices: &[usize], + partial_counts: &UInt64Array, + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if partial_counts.null_count() == 0 { + accumulate_all( + group_indices, + partial_counts, + opt_filter, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ) + } else { + accumulate_all_nullable( + group_indices, + partial_counts, + opt_filter, 
|group_index, partial_count, is_valid| { + if is_valid { + self.counts[group_index] += partial_count; + } + }, + ) + } + } + + /// Returns a NullBuffer representing which group_indices have + /// null values (if they saw a null input) + /// Resets `self.null_inputs`; + fn build_nulls(&mut self) -> Option { + let nulls = NullBuffer::new(self.null_inputs.finish()); + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator +where T: ArrowNumericType + Send +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + self.increment_counts(group_indices, values, opt_filter, total_num_groups); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values.get(0).unwrap().as_primitive::(); + self.update_counts_with_partial_counts( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + ); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let counts = std::mem::take(&mut self.counts); + let nulls = self.build_nulls(); + + // don't evaluate averages with null inputs to avoid errors on null vaues + let array: PrimitiveArray = if let Some(nulls) = nulls.as_ref() { + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = counts.into_iter().zip(nulls.iter()); + + for (count, is_valid) in iter { + if is_valid { + builder.append_value(count) + } else { + builder.append_null(); + } + } + builder.finish() + } else { + PrimitiveArray::::new(counts.into(), 
nulls) // no copy + }; + // TODO remove cast + let array = cast(&array, &self.return_data_type)?; + + Ok(array) + } + + // return arrays for sums and counts + fn state(&mut self) -> Result> { + // TODO nulls + let nulls = self.build_nulls(); + let counts = std::mem::take(&mut self.counts); + let counts = UInt64Array::from(counts); // zero copy + Ok(vec![ + Arc::new(counts) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + /// count null values for multiple columns /// for each row if one column value is null, then null_count + 1 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { @@ -147,6 +349,38 @@ impl AggregateExpr for Count { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(CountAccumulator::new())) } + + fn create_groups_accumulator(&self) -> Result> { + // instantiate specialized accumulator + match &self.data_type { + DataType::UInt64 => { + Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))) + }, + DataType::Int64 => { + Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))) + }, + DataType::UInt32 => { + Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))) + }, + DataType::Int32 => { + Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))) + } + + _ => Err(DataFusionError::NotImplemented(format!( + "CountGroupsAccumulator not supported for {}", + self.data_type + ))), + } + + } } impl PartialEq for Count { From 754a9ffe5bc7cb23871d0b8a78d8cbcd6860ca79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 3 Jul 2023 12:11:06 +0200 Subject: [PATCH 2/5] WIP count --- .../physical-expr/src/aggregate/count.rs | 219 +++++++----------- 1 file changed, 87 insertions(+), 132 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index c3ad7767b1c7..47b7588ec518 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ 
b/datafusion/physical-expr/src/aggregate/count.rs @@ -25,17 +25,16 @@ use std::sync::Arc; use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr, GroupsAccumulator}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::array::{Array, Int64Array}; use arrow::compute; use arrow::compute::kernels::cast; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::builder::PrimitiveBuilder; use arrow_array::cast::AsArray; -use arrow_array::types::{UInt64Type, Int64Type, UInt32Type, Int32Type}; -use arrow_array::{PrimitiveArray, UInt64Array, ArrowNumericType}; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; +use arrow_array::types::{Int32Type, Int64Type, UInt32Type, UInt64Type}; +use arrow_array::{ArrowNumericType, PrimitiveArray, UInt64Array}; +use arrow_buffer::BooleanBuffer; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -91,115 +90,100 @@ impl Count { /// T::Native and a total count #[derive(Debug)] struct CountGroupsAccumulator -where T: ArrowNumericType + Send, +where + T: ArrowNumericType + Send, { /// The type of the returned count return_data_type: DataType, /// Count per group (use u64 to make UInt64Array) counts: Vec, - - /// If we have seen a null input value for this group_index - null_inputs: BooleanBufferBuilder, - // Bind it to struct - phantom: PhantomData + phantom: PhantomData, } - impl CountGroupsAccumulator -where T: ArrowNumericType + Send, +where + T: ArrowNumericType + Send, { pub fn new(return_data_type: &DataType) -> Self { Self { return_data_type: return_data_type.clone(), counts: vec![], - null_inputs: BooleanBufferBuilder::new(0), - phantom: PhantomData {} + phantom: PhantomData {}, } } - /// Adds one to each group's counter - fn increment_counts( - &mut self, - group_indices: 
&[usize], - values: &PrimitiveArray, - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) { - self.counts.resize(total_num_groups, 0); - - if values.null_count() == 0 { - accumulate_all( - group_indices, - values, - opt_filter, - |group_index, _new_value| { + /// Adds one to each group's counter + fn increment_counts( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if values.null_count() == 0 { + accumulate_all( + group_indices, + values, + opt_filter, + |group_index, _new_value| { + self.counts[group_index] += 1; + }, + ) + } else { + accumulate_all_nullable( + group_indices, + values, + opt_filter, + |group_index, _new_value, is_valid| { + if is_valid { self.counts[group_index] += 1; } - ) - }else { - accumulate_all_nullable( - group_indices, - values, - opt_filter, - |group_index, _new_value, is_valid| { - if is_valid { - self.counts[group_index] += 1; - } - }, - ) - } + }, + ) } + } - /// Adds the counts with the partial counts - fn update_counts_with_partial_counts( - &mut self, - group_indices: &[usize], - partial_counts: &UInt64Array, - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) { - self.counts.resize(total_num_groups, 0); - - if partial_counts.null_count() == 0 { - accumulate_all( - group_indices, - partial_counts, - opt_filter, - |group_index, partial_count| { + /// Adds the counts with the partial counts + fn update_counts_with_partial_counts( + &mut self, + group_indices: &[usize], + partial_counts: &UInt64Array, + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if partial_counts.null_count() == 0 { + accumulate_all( + group_indices, + partial_counts, + opt_filter, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ) + } else { + 
accumulate_all_nullable( + group_indices, + partial_counts, + opt_filter, + |group_index, partial_count, is_valid| { + if is_valid { self.counts[group_index] += partial_count; - }, - ) - } else { - accumulate_all_nullable( - group_indices, - partial_counts, - opt_filter, - |group_index, partial_count, is_valid| { - if is_valid { - self.counts[group_index] += partial_count; - } - }, - ) - } - } - - /// Returns a NullBuffer representing which group_indices have - /// null values (if they saw a null input) - /// Resets `self.null_inputs`; - fn build_nulls(&mut self) -> Option { - let nulls = NullBuffer::new(self.null_inputs.finish()); - if nulls.null_count() > 0 { - Some(nulls) - } else { - None - } + } + }, + ) } + } } -impl GroupsAccumulator for CountGroupsAccumulator -where T: ArrowNumericType + Send +impl GroupsAccumulator for CountGroupsAccumulator +where + T: ArrowNumericType + Send, { fn update_batch( &mut self, @@ -238,24 +222,8 @@ where T: ArrowNumericType + Send fn evaluate(&mut self) -> Result { let counts = std::mem::take(&mut self.counts); - let nulls = self.build_nulls(); - - // don't evaluate averages with null inputs to avoid errors on null vaues - let array: PrimitiveArray = if let Some(nulls) = nulls.as_ref() { - let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); - let iter = counts.into_iter().zip(nulls.iter()); - - for (count, is_valid) in iter { - if is_valid { - builder.append_value(count) - } else { - builder.append_null(); - } - } - builder.finish() - } else { - PrimitiveArray::::new(counts.into(), nulls) // no copy - }; + + let array = PrimitiveArray::::new(counts.into(), None); // TODO remove cast let array = cast(&array, &self.return_data_type)?; @@ -264,13 +232,9 @@ where T: ArrowNumericType + Send // return arrays for sums and counts fn state(&mut self) -> Result> { - // TODO nulls - let nulls = self.build_nulls(); let counts = std::mem::take(&mut self.counts); let counts = UInt64Array::from(counts); // zero copy - Ok(vec![ 
- Arc::new(counts) as ArrayRef, - ]) + Ok(vec![Arc::new(counts) as ArrayRef]) } fn size(&self) -> usize { @@ -353,33 +317,24 @@ impl AggregateExpr for Count { fn create_groups_accumulator(&self) -> Result> { // instantiate specialized accumulator match &self.data_type { - DataType::UInt64 => { - Ok(Box::new(CountGroupsAccumulator::::new( - &self.data_type, - ))) - }, - DataType::Int64 => { - Ok(Box::new(CountGroupsAccumulator::::new( - &self.data_type, - ))) - }, - DataType::UInt32 => { - Ok(Box::new(CountGroupsAccumulator::::new( - &self.data_type, - ))) - }, - DataType::Int32 => { - Ok(Box::new(CountGroupsAccumulator::::new( - &self.data_type, - ))) - } + DataType::UInt64 => Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))), + DataType::Int64 => Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))), + DataType::UInt32 => Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))), + DataType::Int32 => Ok(Box::new(CountGroupsAccumulator::::new( + &self.data_type, + ))), _ => Err(DataFusionError::NotImplemented(format!( "CountGroupsAccumulator not supported for {}", self.data_type ))), } - } } From 689e51b461e8779856f94717223bbaaab4734e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 3 Jul 2023 16:08:55 +0200 Subject: [PATCH 3/5] WIP sum --- datafusion/physical-expr/src/aggregate/sum.rs | 198 +++++++++++++++++- 1 file changed, 197 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index efa55f060264..425901f91a90 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{ @@ -31,8 +31,13 @@ use arrow::{ }, 
datatypes::Field, }; +use arrow_array::cast::AsArray; +use arrow_array::types::{UInt64Type, Int64Type, UInt32Type, Int32Type, Decimal128Type}; +use arrow_array::{ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; +use log::debug; use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, @@ -44,6 +49,8 @@ use arrow::array::Decimal128Array; use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; +use super::groups_accumulator::accumulate::{accumulate_all, accumulate_all_nullable}; + /// SUM aggregate expression #[derive(Debug, Clone)] pub struct Sum { @@ -141,6 +148,34 @@ impl AggregateExpr for Sum { ))) } + fn create_groups_accumulator(&self) -> Result> { + // instantiate specialized accumulator + match self.data_type { + DataType::UInt64 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, &self.data_type + ))), + DataType::Int64 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, &self.data_type + ))), + DataType::UInt32 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, &self.data_type + ))), + DataType::Int32 => Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, &self.data_type + ))), + DataType::Decimal128(_target_precision, _target_scale) => { + Ok(Box::new(SumGroupsAccumulator::::new( + &self.data_type, &self.data_type + ))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "SumGroupsAccumulator not supported for {}", + self.data_type + ))), + } + } + + fn reverse_expr(&self) -> Option> { Some(Arc::new(self.clone())) } @@ -424,6 +459,167 @@ impl RowAccumulator for SumRowAccumulator { } } +/// An accumulator to compute the average of PrimitiveArray. 
+/// Stores values as native types, and does overflow checking +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count +#[derive(Debug)] +struct SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + /// The type of the internal sum + sum_data_type: DataType, + + /// The type of the returned sum + return_data_type: DataType, + + /// Sums per group, stored as the native type + sums: Vec, + + /// If we have seen a null input value for this group_index + null_inputs: BooleanBufferBuilder, +} + +impl SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self { + debug!( + "SumGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", + std::any::type_name::() + ); + + Self { + return_data_type: return_data_type.clone(), + sum_data_type: sum_data_type.clone(), + sums: vec![], + null_inputs: BooleanBufferBuilder::new(0), + } + } + + /// Adds the values in `values` to self.sums + fn update_sums( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + if self.null_inputs.len() < total_num_groups { + let new_groups = total_num_groups - self.null_inputs.len(); + // All groups start as valid (and are set to null if we + // see a null in the input) + self.null_inputs.append_n(new_groups, true); + } + self.sums + .resize_with(total_num_groups, || T::default_value()); + + if values.null_count() == 0 { + accumulate_all( + group_indices, + values, + opt_filter, + |group_index, new_value| { + // note since add_wrapping doesn't error, we + // simply add values in null sum slots rather than + // checking if they are null first. 
The theory is + // this is faster + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, + ) + } else { + accumulate_all_nullable( + group_indices, + values, + opt_filter, + |group_index, new_value, is_valid| { + if is_valid { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } else { + // input null means this group is now null + self.null_inputs.set_bit(group_index, false); + } + }, + ) + } + } + + /// Returns a NullBuffer representing which group_indices have + /// null values (if they saw a null input) + /// Resets `self.null_inputs`; + fn build_nulls(&mut self) -> Option { + let nulls = NullBuffer::new(self.null_inputs.finish()); + if nulls.null_count() > 0 { + Some(nulls) + } else { + None + } + } +} + +impl GroupsAccumulator for SumGroupsAccumulator +where + T: ArrowNumericType + Send, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values.get(0).unwrap().as_primitive::(); + + self.update_sums(group_indices, values, opt_filter, total_num_groups); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "two arguments to merge_batch"); + // first batch is partial sums + let partial_sums: &PrimitiveArray = values.get(1).unwrap().as_primitive::(); + self.update_sums(group_indices, partial_sums, opt_filter, total_num_groups); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let sums = std::mem::take(&mut self.sums); + let nulls = self.build_nulls(); + + let array = PrimitiveArray::::new(sums.into(), nulls); // no copy + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self) -> Result> { 
+ let nulls = self.build_nulls(); + + let sums = std::mem::take(&mut self.sums); + let sums = PrimitiveArray::::new(sums.into(), nulls); // zero copy + + Ok(vec![Arc::new(sums) as ArrayRef]) + } + + fn size(&self) -> usize { + self.sums.capacity() * std::mem::size_of::() + } +} + #[cfg(test)] mod tests { use super::*; From 7b2015584c471ddd6d0f35fb1c7224615c203fec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 3 Jul 2023 16:17:16 +0200 Subject: [PATCH 4/5] WIP sum --- datafusion/physical-expr/src/aggregate/sum.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 425901f91a90..d17f7ec97630 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -590,7 +590,7 @@ where ) -> Result<()> { assert_eq!(values.len(), 1, "two arguments to merge_batch"); // first batch is partial sums - let partial_sums: &PrimitiveArray = values.get(1).unwrap().as_primitive::(); + let partial_sums: &PrimitiveArray = values.get(0).unwrap().as_primitive::(); self.update_sums(group_indices, partial_sums, opt_filter, total_num_groups); Ok(()) From 6275a9faaae5c36df76941d94b2ba9e92b075bae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 3 Jul 2023 17:20:07 +0200 Subject: [PATCH 5/5] Use `Rows` API --- Cargo.toml | 9 ++++++++ .../src/physical_plan/aggregates/row_hash2.rs | 22 ++++++++----------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b5d0a34e7e4d..3d6f5aed88b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,3 +70,12 @@ opt-level = 3 overflow-checks = false panic = 'unwind' rpath = false + +# TODO remove after 43 release +[patch.crates-io] +arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } +arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = 
"d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } +arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } +arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } +arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } +parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "d7fa775cf76c7cd54c6d2a86542115599d8f53ee" } diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs index c248af8f44f2..0a4f1d2ba881 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs @@ -26,7 +26,7 @@ use std::task::{Context, Poll}; use std::vec; use ahash::RandomState; -use arrow::row::{OwnedRow, RowConverter, SortField}; +use arrow::row::{RowConverter, SortField, Rows}; use datafusion_physical_expr::hash_utils::create_hashes; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -163,11 +163,7 @@ pub(crate) struct GroupedHashAggregateStream2 { /// /// The row format is used to compare group keys quickly. This is /// especially important for multi-column group keys. - /// - /// TODO, make this Rows (rather than Vec to reduce - /// allocations once - /// https://github.com/apache/arrow-rs/issues/4466 is available - group_values: Vec, + group_values: Rows, /// scratch space for the current input Batch being /// processed. 
Reused across batches here to avoid reallocations @@ -239,7 +235,7 @@ impl GroupedHashAggregateStream2 { let name = format!("GroupedHashAggregateStream2[{partition}]"); let reservation = MemoryConsumer::new(name).register(context.memory_pool()); let map = RawTable::with_capacity(0); - let group_by_values = vec![]; + let group_by_values = row_converter.empty_rows(0, 0); let current_group_indices = vec![]; timer.done(); @@ -381,7 +377,7 @@ impl GroupedHashAggregateStream2 { // TODO update *allocated based on size of the row // that was just pushed into // aggr_state.group_by_values - group_rows.row(row) == self.group_values[*group_idx].row() + group_rows.row(row) == self.group_values.row(*group_idx) }); let group_idx = match entry { @@ -390,8 +386,8 @@ impl GroupedHashAggregateStream2 { // 1.2 Need to create new entry for the group None => { // Add new entry to aggr_state and save newly created index - let group_idx = self.group_values.len(); - self.group_values.push(group_rows.row(row).owned()); + let group_idx = self.group_values.num_rows(); + self.group_values.push(group_rows.row(row)); // for hasher function, use precomputed hash value self.map.insert_accounted( @@ -438,7 +434,7 @@ impl GroupedHashAggregateStream2 { .zip(input_values.iter()) .zip(filter_values.iter()); - let total_num_groups = self.group_values.len(); + let total_num_groups = self.group_values.num_rows(); for ((acc, values), opt_filter) in t { let acc_size_pre = acc.size(); @@ -482,13 +478,13 @@ impl GroupedHashAggregateStream2 { impl GroupedHashAggregateStream2 { /// Create an output RecordBatch with all group keys and accumulator states/values fn create_batch_from_map(&mut self) -> Result { - if self.group_values.is_empty() { + if self.group_values.num_rows() == 0 { let schema = self.schema.clone(); return Ok(RecordBatch::new_empty(schema)); } // First output rows are the groups - let groups_rows = self.group_values.iter().map(|owned_row| owned_row.row()); + let groups_rows = 
self.group_values.iter().map(|owned_row| owned_row); let mut output: Vec = self.row_converter.convert_rows(groups_rows)?;