Skip to content

Commit

Permalink
feat: add Expr.ewm_mean_by
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 14, 2024
1 parent 37c6303 commit 0ac478c
Show file tree
Hide file tree
Showing 21 changed files with 634 additions and 6 deletions.
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/temporal_conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::datatypes::{ArrowDataType, TimeUnit};

/// Number of seconds in a day
pub const SECONDS_IN_DAY: i64 = 86_400;
/// Number of seconds in an hour
pub const SECONDS_IN_HOUR: i64 = 3_600;
/// Number of milliseconds in a second
pub const MILLISECONDS: i64 = 1_000;
/// Number of microseconds in a second
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ abs = ["polars-plan/abs"]
random = ["polars-plan/random"]
dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "temporal"]
ewma = ["polars-plan/ewma"]
ewma_by = ["polars-plan/ewma_by"]
dot_diagram = ["polars-plan/dot_diagram"]
diagonal_concat = []
unique_counts = ["polars-plan/unique_counts"]
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ repeat_by = []
peaks = []
cum_agg = []
ewma = []
ewma_by = ["dtype-datetime", "dtype-date"]
abs = []
cov = []
gather = []
Expand Down
103 changes: 103 additions & 0 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use num_traits::{Float, FromPrimitive, One, Zero};
use polars_core::prelude::*;

/// Computes an exponentially-weighted moving average of `s` in which the decay applied
/// between two consecutive rows depends on the elapsed time between them in `times`.
///
/// `half_life` is given in nanoseconds and is rescaled to the time unit of `times`
/// before the kernel runs.
///
/// Accepted dtypes:
/// - `s`: `Float64`, `Float32`, or `Int64`/`Int32`/`UInt64`/`UInt32` (cast to `Float64`);
/// - `times`: `Datetime`, `Date` (cast to a millisecond `Datetime`), or
///   `Int64`/`Int32`/`UInt64`/`UInt32` (cast through `Int64` to a nanosecond `Datetime`).
///
/// # Errors
/// Returns `InvalidOperation` for any other dtype combination, and propagates any
/// error raised by the intermediate casts.
pub fn ewm_mean_by(s: &Series, times: &Series, half_life: i64) -> PolarsResult<Series> {
    match (s.dtype(), times.dtype()) {
        (DataType::Float64, DataType::Datetime(_, _)) => {
            // The dtype was matched above, so these downcasts cannot fail.
            let times = times.datetime().unwrap();
            let half_life = adjust_half_life_to_time_unit(half_life, times.time_unit());
            Ok(ewm_mean_by_impl(s.f64().unwrap(), &times.0, half_life).into_series())
        },
        (DataType::Float32, DataType::Datetime(_, _)) => {
            // The dtype was matched above, so these downcasts cannot fail.
            let times = times.datetime().unwrap();
            let half_life = adjust_half_life_to_time_unit(half_life, times.time_unit());
            Ok(ewm_mean_by_impl(s.f32().unwrap(), &times.0, half_life).into_series())
        },
        // Integer values are promoted to Float64 and handled by the arm above.
        (
            DataType::Int64 | DataType::Int32 | DataType::UInt64 | DataType::UInt32,
            DataType::Datetime(_, _),
        ) => ewm_mean_by(&s.cast(&DataType::Float64)?, times, half_life),
        // Normalise the `by` column to a Datetime, then re-dispatch on the values dtype.
        (_, DataType::Date) => ewm_mean_by(
            s,
            &times.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
            half_life,
        ),
        (_, DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32) => {
            ewm_mean_by(
                s,
                &times
                    .cast(&DataType::Int64)?
                    .cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
                half_life,
            )
        },
        _ => {
            // List every dtype that is actually accepted for `by`, including the integer ones.
            polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
                Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
                UInt64, or UInt32")
        },
    }
}

/// Time-weighted exponential moving average kernel.
///
/// Walks `values` and `times` in lockstep. The first row where both are non-null seeds
/// the average with its value. Every later row with both inputs non-null is blended
/// with the previous result using
/// `alpha = 1 - 0.5^(delta_time / half_life)`, where `delta_time` is the difference to
/// the time of the previous non-null row. Rows where either input is null produce a
/// null and leave the running state untouched.
///
/// `half_life` must be expressed in the same unit as the integers in `times`.
///
/// The output always has exactly one entry per zipped input row, including when no
/// row has both a non-null time and a non-null value.
fn ewm_mean_by_impl<T>(
    values: &ChunkedArray<T>,
    times: &Int64Chunked,
    half_life: i64,
) -> ChunkedArray<T>
where
    T: PolarsFloatType,
    T::Native: Zero + One + Float,
{
    // Loop-invariant conversions, hoisted out of the per-row closure.
    let one_half = T::Native::from_f64(0.5).unwrap();
    let half_life = T::Native::from_i64(half_life).unwrap();

    // `(time, result)` of the most recent row where both inputs were non-null.
    let mut prev: Option<(i64, T::Native)> = None;

    let out: Vec<Option<T::Native>> = values
        .iter()
        .zip(times.iter())
        .map(|(value, time)| match (time, value) {
            (Some(time), Some(value)) => {
                let result = match prev {
                    // The first valid observation seeds the average.
                    None => value,
                    Some((prev_time, prev_result)) => {
                        if value != prev_result {
                            let delta_time = time - prev_time;
                            // equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)
                            let one_minus_alpha = one_half
                                .powf(T::Native::from_i64(delta_time).unwrap() / half_life);
                            let alpha = T::Native::one() - one_minus_alpha;
                            alpha * value + one_minus_alpha * prev_result
                        } else {
                            // Blending a value with itself is a no-op; skip the `powf`.
                            value
                        }
                    },
                };
                prev = Some((time, result));
                Some(result)
            },
            _ => None,
        })
        .collect();
    ChunkedArray::<T>::from_iter_options(values.name(), out.into_iter())
}

/// Rescales a half-life given in nanoseconds to the tick size of `time_unit`.
///
/// Uses truncating integer division, so any remainder smaller than one tick of
/// `time_unit` is dropped.
fn adjust_half_life_to_time_unit(half_life: i64, time_unit: TimeUnit) -> i64 {
    let nanos_per_tick: i64 = match time_unit {
        TimeUnit::Nanoseconds => 1,
        TimeUnit::Microseconds => 1_000,
        TimeUnit::Milliseconds => 1_000_000,
    };
    half_life / nanos_per_tick
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod cut;
mod diff;
#[cfg(feature = "ewma")]
mod ewm;
#[cfg(feature = "ewma_by")]
mod ewm_by;
#[cfg(feature = "round_series")]
mod floor_divide;
#[cfg(feature = "fused")]
Expand Down Expand Up @@ -78,6 +80,8 @@ pub use cut::*;
pub use diff::*;
#[cfg(feature = "ewma")]
pub use ewm::*;
#[cfg(feature = "ewma_by")]
pub use ewm_by::*;
#[cfg(feature = "round_series")]
pub use floor_divide::*;
#[cfg(feature = "fused")]
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ abs = ["polars-ops/abs"]
random = ["polars-core/random"]
dynamic_group_by = ["polars-core/dynamic_group_by"]
ewma = ["polars-ops/ewma"]
ewma_by = ["polars-ops/ewma_by"]
dot_diagram = []
unique_counts = ["polars-ops/unique_counts"]
log = ["polars-ops/log"]
Expand Down Expand Up @@ -205,6 +206,7 @@ features = [
"cutqcut",
"async",
"ewma",
"ewma_by",
"random",
"chunked_ids",
"repeat_by",
Expand Down
37 changes: 37 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/ewm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,43 @@ pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult<Series>
polars_ops::prelude::ewm_mean(s, options)
}

/// Validates `half_life` and dispatches to `polars_ops::prelude::ewm_mean_by`.
///
/// `s[0]` holds the values and `s[1]` the `by` (time) column.
///
/// # Errors
/// Returns `InvalidOperation` when `half_life` is negative or is not a constant
/// duration for the time zone of `s[1]`; otherwise propagates the kernel's result.
///
/// When `warn_if_unsorted` is set and `s[1]` is not flagged as sorted ascending, a
/// warning is emitted; the computation still proceeds.
pub(super) fn ewm_mean_by(
    s: &[Series],
    half_life: Duration,
    warn_if_unsorted: bool,
) -> PolarsResult<Series> {
    let time_zone = match s[1].dtype() {
        DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()),
        _ => None,
    };
    if half_life.negative() {
        polars_bail!(InvalidOperation: "half_life cannot be negative");
    }
    if !half_life.is_constant_duration(time_zone) {
        // Every line ends with `\` so that no raw newline or source indentation
        // leaks into the rendered message.
        polars_bail!(InvalidOperation: "expected `half_life` to be a constant duration \
            (i.e. one independent of differing month durations or of daylight savings time), got {}.\n\
            \n\
            You may want to try:\n\
            - using `'730h'` instead of `'1mo'`;\n\
            - using `'24h'` instead of `'1d'` if your series is time-zone-aware.", half_life);
    }
    if s[1].is_sorted_flag() != IsSorted::Ascending && warn_if_unsorted {
        polars_warn!(format!(
            "Series is not known to be sorted by `by` column in {0} operation.\n\
            \n\
            To silence this warning, you may want to try:\n\
            - sorting your data by your `by` column beforehand;\n\
            - setting `.set_sorted()` if you already know your data is sorted;\n\
            - passing `warn_if_unsorted=False` if this warning is a false-positive\n \
            (this is known to happen when combining rolling aggregations with `over`);\n\n\
            before calling `{0}`.\n",
            "ewm_mean_by",
        ));
    }
    // `half_life` is a constant duration so we can safely use `duration_ns()`.
    polars_ops::prelude::ewm_mean_by(&s[0], &s[1], half_life.duration_ns())
}

/// Delegates to `polars_ops::prelude::ewm_std`, forwarding `s` and `options` unchanged.
pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult<Series> {
polars_ops::prelude::ewm_std(s, options)
}
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ pub enum FunctionExpr {
EwmMean {
options: EWMOptions,
},
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life: Duration,
warn_if_unsorted: bool,
},
#[cfg(feature = "ewma")]
EwmStd {
options: EWMOptions,
Expand Down Expand Up @@ -520,6 +525,11 @@ impl Hash for FunctionExpr {
BackwardFill { limit } | ForwardFill { limit } => limit.hash(state),
#[cfg(feature = "ewma")]
EwmMean { options } => options.hash(state),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
warn_if_unsorted,
} => (half_life, warn_if_unsorted).hash(state),
#[cfg(feature = "ewma")]
EwmStd { options } => options.hash(state),
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -705,6 +715,8 @@ impl Display for FunctionExpr {
MeanHorizontal => "mean_horizontal",
#[cfg(feature = "ewma")]
EwmMean { .. } => "ewm_mean",
#[cfg(feature = "ewma_by")]
EwmMeanBy { .. } => "ewm_mean_by",
#[cfg(feature = "ewma")]
EwmStd { .. } => "ewm_std",
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -1073,6 +1085,11 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
MeanHorizontal => wrap!(dispatch::mean_horizontal),
#[cfg(feature = "ewma")]
EwmMean { options } => map!(ewm::ewm_mean, options),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
warn_if_unsorted,
} => map_as_slice!(ewm::ewm_mean_by, half_life, warn_if_unsorted),
#[cfg(feature = "ewma")]
EwmStd { options } => map!(ewm::ewm_std, options),
#[cfg(feature = "ewma")]
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ impl FunctionExpr {
MeanHorizontal => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
EwmMean { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma_by")]
EwmMeanBy { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
EwmStd { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "ewma")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub(super) fn date_offset(s: &[Series]) -> PolarsResult<Series> {
let offset = Duration::parse(offset);
tz.is_none()
|| tz.as_deref() == Some("UTC")
|| offset.is_constant_duration()
|| offset.is_constant_duration(tz.as_deref())
},
None => false,
},
Expand Down
14 changes: 14 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,20 @@ impl Expr {
self.apply_private(FunctionExpr::EwmMean { options })
}

#[cfg(feature = "ewma_by")]
/// Calculate the exponentially-weighted moving average of this expression, using
/// `times` as the accompanying time column.
///
/// `half_life` and `warn_if_unsorted` are stored on the resulting
/// `FunctionExpr::EwmMeanBy` node and interpreted when the expression is evaluated.
pub fn ewm_mean_by(self, times: Expr, half_life: Duration, warn_if_unsorted: bool) -> Self {
    let function = FunctionExpr::EwmMeanBy {
        half_life,
        warn_if_unsorted,
    };
    // `times` is the only additional input besides `self`.
    let extra_inputs = [times];
    self.apply_many_private(function, &extra_inputs, false, false)
}

#[cfg(feature = "ewma")]
/// Calculate the exponentially-weighted moving standard deviation.
pub fn ewm_std(self, options: EWMOptions) -> Self {
Expand Down
21 changes: 16 additions & 5 deletions crates/polars-time/src/windows/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use arrow::legacy::kernels::{Ambiguous, NonExistent};
use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions::{
timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, MILLISECONDS,
NANOSECONDS,
};
use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
use polars_core::export::arrow::temporal_conversions::MICROSECONDS;
use polars_core::prelude::{
datetime_to_timestamp_ms, datetime_to_timestamp_ns, datetime_to_timestamp_us, polars_bail,
PolarsResult,
};
use polars_core::utils::arrow::temporal_conversions::NANOSECONDS;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -32,7 +32,7 @@ pub struct Duration {
months: i64,
// the number of weeks for the duration
weeks: i64,
// the number of nanoseconds for the duration
// the number of days for the duration
days: i64,
// the number of nanoseconds for the duration
nsecs: i64,
Expand Down Expand Up @@ -363,16 +363,27 @@ impl Duration {
self.nsecs == 0
}

pub fn is_constant_duration(&self) -> bool {
self.months == 0 && self.weeks == 0 && self.days == 0
/// Returns `true` when this duration is treated as a fixed span of time for data in
/// `time_zone`.
///
/// A non-zero months component never qualifies. Weeks and days are accepted only when
/// `time_zone` is `None` or `"UTC"`.
pub fn is_constant_duration(&self, time_zone: Option<&str>) -> bool {
    // For time-zone-aware, non-UTC data, 1 calendar day is not necessarily 24 hours
    // due to daylight savings time, so week and day components are rejected there.
    let calendar_days_are_fixed = matches!(time_zone, None | Some("UTC"));
    self.months == 0 && (calendar_days_are_fixed || (self.weeks == 0 && self.days == 0))
}

/// Returns only the nanoseconds component of the `Duration`; the days, weeks and
/// months components are not included.
pub fn nanoseconds(&self) -> i64 {
self.nsecs
}

/// Estimated duration of the window duration. Not a very good one if months != 0.
/// Returns the `negative` flag of this `Duration`, i.e. whether the duration is negative.
pub fn negative(&self) -> bool {
self.negative
}

/// Estimated duration of the window duration. Not a very good one if not a constant duration.
#[doc(hidden)]
pub const fn duration_ns(&self) -> i64 {
self.months * 28 * 24 * 3600 * NANOSECONDS
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ dot_diagram = ["polars-lazy?/dot_diagram"]
dot_product = ["polars-core/dot_product"]
dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"]
ewma = ["polars-ops/ewma", "polars-lazy?/ewma"]
ewma_by = ["polars-ops/ewma_by", "polars-lazy?/ewma_by"]
extract_groups = ["polars-lazy?/extract_groups"]
extract_jsonpath = [
"polars-core/strings",
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ features = [
"dtype-full",
"dynamic_group_by",
"ewma",
"ewma_by",
"fmt",
"interpolate",
"is_first_distinct",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Computation
Expr.dot
Expr.entropy
Expr.ewm_mean
Expr.ewm_mean_by
Expr.ewm_std
Expr.ewm_var
Expr.exp
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Computation
Series.dot
Series.entropy
Series.ewm_mean
Series.ewm_mean_by
Series.ewm_std
Series.ewm_var
Series.exp
Expand Down
Loading

0 comments on commit 0ac478c

Please sign in to comment.