Skip to content

Commit

Permalink
feat: add ewm_mean_by (#15638)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Apr 15, 2024
1 parent 8ef2e21 commit 1e7fa8e
Show file tree
Hide file tree
Showing 19 changed files with 710 additions and 1 deletion.
1 change: 1 addition & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ abs = ["polars-plan/abs"]
random = ["polars-plan/random"]
dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "temporal"]
ewma = ["polars-plan/ewma"]
ewma_by = ["polars-plan/ewma_by"]
dot_diagram = ["polars-plan/dot_diagram"]
diagonal_concat = []
unique_counts = ["polars-plan/unique_counts"]
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ repeat_by = []
peaks = []
cum_agg = []
ewma = []
ewma_by = []
abs = []
cov = []
gather = []
Expand Down
176 changes: 176 additions & 0 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use num_traits::{Float, FromPrimitive, One, Zero};
use polars_core::prelude::*;

pub fn ewm_mean_by(
s: &Series,
times: &Series,
half_life: i64,
assume_sorted: bool,
) -> PolarsResult<Series> {
let func = match assume_sorted {
true => ewm_mean_by_impl_sorted,
false => ewm_mean_by_impl,
};
match (s.dtype(), times.dtype()) {
(DataType::Float64, DataType::Int64) => {
Ok(func(s.f64().unwrap(), times.i64().unwrap(), half_life).into_series())
},
(DataType::Float32, DataType::Int64) => {
Ok(ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life).into_series())
},
#[cfg(feature = "dtype-datetime")]
(_, DataType::Datetime(time_unit, _)) => {
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
},
#[cfg(feature = "dtype-date")]
(_, DataType::Date) => ewm_mean_by(
s,
&times.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
half_life,
assume_sorted,
),
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => {
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
},
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
ewm_mean_by(
&s.cast(&DataType::Float64)?,
times,
half_life,
assume_sorted,
)
},
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
UInt64, or UInt32")
},
}
}

/// Computes the time-weighted EWM mean when `times` is not known to be sorted.
///
/// Both inputs are gathered into `times` order (default `arg_sort` options),
/// the recurrence is evaluated in that order, and every result is written
/// back to the slot of the row it came from, so the output lines up with the
/// caller's original row order. Rows where either the value or the time is
/// null produce a null and do not advance the running state.
fn ewm_mean_by_impl<T>(
    values: &ChunkedArray<T>,
    times: &Int64Chunked,
    half_life: i64,
) -> ChunkedArray<T>
where
    T: PolarsFloatType,
    T::Native: Float + Zero + One,
    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
{
    let order = times.arg_sort(Default::default());
    // SAFETY: `order` was produced by `arg_sort` on `times`, so its indices
    // address rows of `times`. NOTE(review): for `values` this additionally
    // assumes `values` has at least as many rows as `times` — confirm callers.
    let sorted_values = unsafe { values.take_unchecked(&order) };
    let sorted_times = unsafe { times.take_unchecked(&order) };
    let order = order
        .cont_slice()
        .expect("`arg_sort` should have returned a single chunk");

    // Pre-filled with nulls: only rows with a valid (time, value) pair are
    // overwritten below.
    let mut out = vec![None; sorted_times.len()];

    // `(time, result)` of the most recent valid row; `None` until one is seen.
    let mut state: Option<(i64, T::Native)> = None;
    for ((dest, value), time) in order
        .iter()
        .zip(sorted_values.iter())
        .zip(sorted_times.iter())
    {
        let (Some(time), Some(value)) = (time, value) else {
            continue;
        };
        let result = match state {
            // The first valid row seeds the average with its own value.
            None => value,
            Some((prev_time, prev_result)) => {
                update(value, prev_result, time, prev_time, half_life)
            },
        };
        state = Some((time, result));
        // SAFETY: `dest` is an index returned by `arg_sort` on `times`, and
        // `out` was allocated with one slot per gathered row.
        unsafe {
            *out.get_unchecked_mut(*dest as usize) = Some(result);
        }
    }
    ChunkedArray::<T>::from_iter_options(sorted_values.name(), out.into_iter())
}

/// Fastpath if `times` is known to already be sorted.
///
/// Walks `values` and `times` in lockstep and produces exactly one output per
/// input row:
///
/// * a row where either the value or the time is null yields a null and
///   leaves the running state untouched — this includes null rows that come
///   *before* the first valid row, which must still be emitted so the output
///   has the same length as the input;
/// * the first valid row seeds the average with its own value;
/// * each later valid row is folded in with `update`, weighted by the time
///   elapsed since the previous valid row.
fn ewm_mean_by_impl_sorted<T>(
    values: &ChunkedArray<T>,
    times: &Int64Chunked,
    half_life: i64,
) -> ChunkedArray<T>
where
    T: PolarsFloatType,
    T::Native: Float + Zero + One,
{
    let mut out = Vec::with_capacity(times.len());

    // `(time, result)` of the most recent valid row; `None` until one is seen.
    let mut state: Option<(i64, T::Native)> = None;
    for (value, time) in values.iter().zip(times.iter()) {
        let result_opt = match (time, value) {
            (Some(time), Some(value)) => {
                let result = match state {
                    None => value,
                    Some((prev_time, prev_result)) => {
                        update(value, prev_result, time, prev_time, half_life)
                    },
                };
                state = Some((time, result));
                Some(result)
            },
            _ => None,
        };
        // Push unconditionally so leading nulls are not dropped.
        out.push(result_opt);
    }
    ChunkedArray::<T>::from_iter_options(values.name(), out.into_iter())
}

/// Converts a half-life given in nanoseconds into the unit of `time_unit`.
///
/// Uses integer division, so any remainder below one target unit is
/// truncated.
fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {
    let ns_per_unit: i64 = match time_unit {
        TimeUnit::Milliseconds => 1_000_000,
        TimeUnit::Microseconds => 1_000,
        TimeUnit::Nanoseconds => 1,
    };
    half_life / ns_per_unit
}

/// Advances the exponentially-weighted mean by one observation.
///
/// The weight of the new `value` is `alpha = 1 - 0.5^(delta_time / half_life)`
/// where `delta_time = time - prev_time`; the previous result keeps the
/// remaining `1 - alpha`. When `value` compares equal to `prev_result` the
/// value is returned directly without evaluating the power.
fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T
where
    T: Float + Zero + One + FromPrimitive,
{
    if value == prev_result {
        return value;
    }

    let elapsed = time - prev_time;
    let half = T::from_f64(0.5).unwrap();
    let exponent = T::from_i64(elapsed).unwrap() / T::from_i64(half_life).unwrap();
    // equivalent to: alpha = 1 - exp(-elapsed*ln(2) / half_life)
    let decay = half.powf(exponent);
    let alpha = T::one() - decay;
    alpha * value + decay * prev_result
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod cut;
mod diff;
#[cfg(feature = "ewma")]
mod ewm;
#[cfg(feature = "ewma_by")]
mod ewm_by;
#[cfg(feature = "round_series")]
mod floor_divide;
#[cfg(feature = "fused")]
Expand Down Expand Up @@ -78,6 +80,8 @@ pub use cut::*;
pub use diff::*;
#[cfg(feature = "ewma")]
pub use ewm::*;
#[cfg(feature = "ewma_by")]
pub use ewm_by::*;
#[cfg(feature = "round_series")]
pub use floor_divide::*;
#[cfg(feature = "fused")]
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ abs = ["polars-ops/abs"]
random = ["polars-core/random"]
dynamic_group_by = ["polars-core/dynamic_group_by"]
ewma = ["polars-ops/ewma"]
ewma_by = ["polars-ops/ewma_by"]
dot_diagram = []
unique_counts = ["polars-ops/unique_counts"]
log = ["polars-ops/log"]
Expand Down Expand Up @@ -205,6 +206,7 @@ features = [
"cutqcut",
"async",
"ewma",
"ewma_by",
"random",
"chunked_ids",
"repeat_by",
Expand Down
25 changes: 25 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/ewm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@ pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult<Series>
polars_ops::prelude::ewm_mean(s, options)
}

/// Expression-level entry point for `ewm_mean_by`.
///
/// `s[0]` holds the values and `s[1]` the time ("by") column. `half_life`
/// must be non-negative and a constant duration for the time zone of the
/// time column (if it has one); it is then converted to nanoseconds and
/// forwarded to the `polars_ops` kernel. The sort-free path is requested
/// when `check_sorted` is `false`, or when the time column carries the
/// ascending sorted flag.
pub(super) fn ewm_mean_by(
    s: &[Series],
    half_life: Duration,
    check_sorted: bool,
) -> PolarsResult<Series> {
    let times = &s[1];
    let values = &s[0];

    let time_zone = if let DataType::Datetime(_, Some(tz)) = times.dtype() {
        Some(tz.as_str())
    } else {
        None
    };

    polars_ensure!(!half_life.negative(), InvalidOperation: "half_life cannot be negative");
    polars_ensure!(half_life.is_constant_duration(time_zone),
        InvalidOperation: "expected `half_life` to be a constant duration \
        (i.e. one independent of differing month durations or of daylight savings time), got {}.\n\
        \n\
        You may want to try:\n\
        - using `'730h'` instead of `'1mo'`\n\
        - using `'24h'` instead of `'1d'` if your series is time-zone-aware", half_life);

    // The duration was verified to be constant above, so its nanosecond
    // length is exact rather than an estimate.
    let half_life_ns = half_life.duration_ns();

    // Only consult the sorted flag when the caller asked for the check.
    let assume_sorted = if check_sorted {
        matches!(times.is_sorted_flag(), IsSorted::Ascending)
    } else {
        true
    };

    polars_ops::prelude::ewm_mean_by(values, times, half_life_ns, assume_sorted)
}

pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult<Series> {
polars_ops::prelude::ewm_std(s, options)
}
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ pub enum FunctionExpr {
EwmMean {
options: EWMOptions,
},
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life: Duration,
check_sorted: bool,
},
#[cfg(feature = "ewma")]
EwmStd {
options: EWMOptions,
Expand Down Expand Up @@ -520,6 +525,11 @@ impl Hash for FunctionExpr {
BackwardFill { limit } | ForwardFill { limit } => limit.hash(state),
#[cfg(feature = "ewma")]
EwmMean { options } => options.hash(state),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
check_sorted,
} => (half_life, check_sorted).hash(state),
#[cfg(feature = "ewma")]
EwmStd { options } => options.hash(state),
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -705,6 +715,8 @@ impl Display for FunctionExpr {
MeanHorizontal => "mean_horizontal",
#[cfg(feature = "ewma")]
EwmMean { .. } => "ewm_mean",
#[cfg(feature = "ewma_by")]
EwmMeanBy { .. } => "ewm_mean_by",
#[cfg(feature = "ewma")]
EwmStd { .. } => "ewm_std",
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -1073,6 +1085,11 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
MeanHorizontal => wrap!(dispatch::mean_horizontal),
#[cfg(feature = "ewma")]
EwmMean { options } => map!(ewm::ewm_mean, options),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
check_sorted,
} => map_as_slice!(ewm::ewm_mean_by, half_life, check_sorted),
#[cfg(feature = "ewma")]
EwmStd { options } => map!(ewm::ewm_std, options),
#[cfg(feature = "ewma")]
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ impl FunctionExpr {
MeanHorizontal => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
EwmMean { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma_by")]
EwmMeanBy { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
EwmStd { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
Expand Down
14 changes: 14 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,20 @@ impl Expr {
self.apply_private(FunctionExpr::EwmMean { options })
}

#[cfg(feature = "ewma_by")]
/// Calculate the exponentially-weighted moving average of this expression,
/// weighting observations by their distance along the `times` expression.
///
/// `half_life` and `check_sorted` are stored on the resulting
/// `FunctionExpr::EwmMeanBy` node.
pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self {
    let function = FunctionExpr::EwmMeanBy {
        half_life,
        check_sorted,
    };
    // `times` is passed as the single additional input of the function.
    let extra_inputs = [times];
    self.apply_many_private(function, &extra_inputs, false, false)
}

#[cfg(feature = "ewma")]
/// Calculate the exponentially-weighted moving standard deviation.
pub fn ewm_std(self, options: EWMOptions) -> Self {
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-time/src/windows/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use arrow::legacy::kernels::{Ambiguous, NonExistent};
use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions::{
timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, MILLISECONDS,
NANOSECONDS,
};
use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
use polars_core::export::arrow::temporal_conversions::MICROSECONDS;
use polars_core::prelude::{
datetime_to_timestamp_ms, datetime_to_timestamp_ns, datetime_to_timestamp_us, polars_bail,
PolarsResult,
};
use polars_core::utils::arrow::temporal_conversions::NANOSECONDS;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -378,6 +378,11 @@ impl Duration {
self.nsecs
}

/// Returns whether this duration is negative, i.e. the value of its
/// `negative` flag.
///
/// Declared `const` for consistency with the sibling accessor `duration_ns`.
pub const fn negative(&self) -> bool {
    self.negative
}

/// Estimated duration of the window duration. Not a very good one if not a constant duration.
#[doc(hidden)]
pub const fn duration_ns(&self) -> i64 {
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ dot_diagram = ["polars-lazy?/dot_diagram"]
dot_product = ["polars-core/dot_product"]
dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"]
ewma = ["polars-ops/ewma", "polars-lazy?/ewma"]
ewma_by = ["polars-ops/ewma_by", "polars-lazy?/ewma_by"]
extract_groups = ["polars-lazy?/extract_groups"]
extract_jsonpath = [
"polars-core/strings",
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ features = [
"dtype-full",
"dynamic_group_by",
"ewma",
"ewma_by",
"fmt",
"interpolate",
"is_first_distinct",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Computation
Expr.dot
Expr.entropy
Expr.ewm_mean
Expr.ewm_mean_by
Expr.ewm_std
Expr.ewm_var
Expr.exp
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Computation
Series.dot
Series.entropy
Series.ewm_mean
Series.ewm_mean_by
Series.ewm_std
Series.ewm_var
Series.exp
Expand Down
Loading

0 comments on commit 1e7fa8e

Please sign in to comment.