From 09f26d8ec1726b1c3091ea0987409e8237573dd2 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 16 May 2024 12:18:48 +0100 Subject: [PATCH] perf: use zeroed vec in ewm_mean_by for sorted fastpath (#16265) --- crates/polars-core/src/utils/mod.rs | 18 +++++++++ crates/polars-ops/src/series/ops/ewm_by.rs | 45 ++++++++++------------ 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 7b23e95ffdfbc..b3fb970707d58 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -1,4 +1,6 @@ mod any_value; +use arrow::compute::concatenate::concatenate_validities; +use arrow::compute::utils::combine_validities_and; pub mod flatten; pub(crate) mod series; mod supertype; @@ -834,6 +836,22 @@ where } } +pub fn binary_concatenate_validities<'a, T, B>( + left: &'a ChunkedArray<T>, + right: &'a ChunkedArray<B>, +) -> Option<Bitmap> +where + B: PolarsDataType, + T: PolarsDataType, +{ + let (left, right) = align_chunks_binary(left, right); + let left_chunk_refs: Vec<_> = left.chunks().iter().map(|c| &**c).collect(); + let left_validity = concatenate_validities(&left_chunk_refs); + let right_chunk_refs: Vec<_> = right.chunks().iter().map(|c| &**c).collect(); + let right_validity = concatenate_validities(&right_chunk_refs); + combine_validities_and(left_validity.as_ref(), right_validity.as_ref()) +} + pub trait IntoVec<T> { fn into_vec(self) -> Vec<T>; } diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs index ce374f7575078..4b1d047269ef4 100644 --- a/crates/polars-ops/src/series/ops/ewm_by.rs +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -1,9 +1,7 @@ -use arrow::compute::concatenate::concatenate_validities; -use arrow::compute::utils::combine_validities_and; use bytemuck::allocation::zeroed_vec; use num_traits::{Float, FromPrimitive, One, Zero}; use polars_core::prelude::*; -use
polars_core::utils::align_chunks_binary; +use polars_core::utils::binary_concatenate_validities; pub fn ewm_mean_by( s: &Series, @@ -108,12 +106,7 @@ where }); let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true)); if (times.null_count() > 0) || (values.null_count() > 0) { - let (times, values) = align_chunks_binary(times, values); - let times_chunk_refs: Vec<_> = times.chunks().iter().map(|c| &**c).collect(); - let times_validity = concatenate_validities(&times_chunk_refs); - let values_chunk_refs: Vec<_> = values.chunks().iter().map(|c| &**c).collect(); - let values_validity = concatenate_validities(&values_chunk_refs); - let validity = combine_validities_and(times_validity.as_ref(), values_validity.as_ref()); + let validity = binary_concatenate_validities(times, values); arr = arr.with_validity_typed(validity); } ChunkedArray::with_chunk(values.name(), arr) @@ -129,7 +122,7 @@ where T: PolarsFloatType, T::Native: Float + Zero + One, { - let mut out = Vec::with_capacity(times.len()); + let mut out: Vec<_> = zeroed_vec(times.len()); let mut skip_rows: usize = 0; let mut prev_time: i64 = 0; @@ -138,30 +131,34 @@ where if let (Some(time), Some(value)) = (time, value) { prev_time = time; prev_result = value; - out.push(Some(prev_result)); + unsafe { + *out.get_unchecked_mut(idx) = prev_result; + } skip_rows = idx + 1; break; - } else { - out.push(None) } } values .iter() .zip(times.iter()) + .enumerate() .skip(skip_rows) - .for_each(|(value, time)| { - let result_opt = match (time, value) { - (Some(time), Some(value)) => { - let result = update(value, prev_result, time, prev_time, half_life); - prev_time = time; - prev_result = result; - Some(result) - }, - _ => None, + .for_each(|(idx, (value, time))| { + if let (Some(time), Some(value)) = (time, value) { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + unsafe { + *out.get_unchecked_mut(idx) = result; + } }; -
out.push(result_opt); }); - ChunkedArray::<T>::from_iter_options(values.name(), out.into_iter()) + let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true)); + if (times.null_count() > 0) || (values.null_count() > 0) { + let validity = binary_concatenate_validities(times, values); + arr = arr.with_validity_typed(validity); + } + ChunkedArray::with_chunk(values.name(), arr) } fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {