Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 15, 2024
1 parent f400301 commit 9c0844b
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 27 deletions.
33 changes: 23 additions & 10 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use num_traits::{Float, FromPrimitive, One, Zero};
use polars_core::prelude::*;

pub fn ewm_mean_by(s: &Series, times: &Series, half_life: i64, assume_sorted: bool) -> PolarsResult<Series> {
pub fn ewm_mean_by(
s: &Series,
times: &Series,
half_life: i64,
assume_sorted: bool,
) -> PolarsResult<Series> {
let func = match assume_sorted {
true => ewm_mean_by_impl_sorted,
false => ewm_mean_by_impl,
Expand Down Expand Up @@ -30,7 +34,12 @@ pub fn ewm_mean_by(s: &Series, times: &Series, half_life: i64, assume_sorted: bo
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
},
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
ewm_mean_by(&s.cast(&DataType::Float64)?, times, half_life, assume_sorted)
ewm_mean_by(
&s.cast(&DataType::Float64)?,
times,
half_life,
assume_sorted,
)
},
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Expand All @@ -40,15 +49,16 @@ pub fn ewm_mean_by(s: &Series, times: &Series, half_life: i64, assume_sorted: bo
}
}

/// Sort for the
/// Sort on behalf of user
fn ewm_mean_by_impl<T>(
values: &ChunkedArray<T>,
times: &Int64Chunked,
half_life: i64,
) -> ChunkedArray<T> where
) -> ChunkedArray<T>
where
T: PolarsFloatType,
T::Native: Float + Zero + One,
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
{
let sorting_indices = times.arg_sort(Default::default());
let values = unsafe { values.take_unchecked(&sorting_indices) };
Expand Down Expand Up @@ -99,7 +109,8 @@ fn ewm_mean_by_impl_sorted<T>(
values: &ChunkedArray<T>,
times: &Int64Chunked,
half_life: i64,
) -> ChunkedArray<T> where
) -> ChunkedArray<T>
where
T: PolarsFloatType,
T::Native: Float + Zero + One,
{
Expand Down Expand Up @@ -147,11 +158,13 @@ fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {
fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T
where
T: Float + Zero + One + FromPrimitive,
{
{
if value != prev_result {
let delta_time = time - prev_time;
// equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)
let one_minus_alpha = T::from_f64(0.5).unwrap().powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());
let one_minus_alpha = T::from_f64(0.5)
.unwrap()
.powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());
let alpha = T::one() - one_minus_alpha;
alpha * value + one_minus_alpha * prev_result
} else {
Expand Down
10 changes: 8 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/ewm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult<Series>
polars_ops::prelude::ewm_mean(s, options)
}

pub(super) fn ewm_mean_by(s: &[Series], half_life: Duration, check_sorted: bool) -> PolarsResult<Series> {
pub(super) fn ewm_mean_by(
s: &[Series],
half_life: Duration,
check_sorted: bool,
) -> PolarsResult<Series> {
let time_zone = match s[1].dtype() {
DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()),
_ => None,
Expand All @@ -22,8 +26,10 @@ pub(super) fn ewm_mean_by(s: &[Series], half_life: Duration, check_sorted: bool)
}
// `half_life` is a constant duration so we can safely use `duration_ns()`.
let half_life = half_life.duration_ns();
let values = &s[0];
let times = &s[1];
let assume_sorted = !check_sorted || times.is_sorted_flag() == IsSorted::Ascending;
polars_ops::prelude::ewm_mean_by(&s[0], &s[1], half_life, assume_sorted)
polars_ops::prelude::ewm_mean_by(values, times, half_life, assume_sorted)
}

pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult<Series> {
Expand Down
10 changes: 8 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,10 @@ impl Hash for FunctionExpr {
#[cfg(feature = "ewma")]
EwmMean { options } => options.hash(state),
#[cfg(feature = "ewma_by")]
EwmMeanBy { half_life , check_sorted } => (half_life, check_sorted).hash(state),
EwmMeanBy {
half_life,
check_sorted,
} => (half_life, check_sorted).hash(state),
#[cfg(feature = "ewma")]
EwmStd { options } => options.hash(state),
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -1083,7 +1086,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "ewma")]
EwmMean { options } => map!(ewm::ewm_mean, options),
#[cfg(feature = "ewma_by")]
EwmMeanBy { half_life , check_sorted } => map_as_slice!(ewm::ewm_mean_by, half_life, check_sorted),
EwmMeanBy {
half_life,
check_sorted,
} => map_as_slice!(ewm::ewm_mean_by, half_life, check_sorted),
#[cfg(feature = "ewma")]
EwmStd { options } => map!(ewm::ewm_std, options),
#[cfg(feature = "ewma")]
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1538,7 +1538,10 @@ impl Expr {
/// Calculate the exponentially-weighted moving average by a time column.
pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self {
self.apply_many_private(
FunctionExpr::EwmMeanBy { half_life, check_sorted },
FunctionExpr::EwmMeanBy {
half_life,
check_sorted,
},
&[times],
false,
false,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8846,7 +8846,7 @@ def ewm_mean_by(
"""
by = parse_as_expression(by)
half_life = parse_as_duration_string(half_life)
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life, check_sorted))
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life))

@deprecate_nonkeyword_arguments(version="0.19.10")
def ewm_std(
Expand Down
13 changes: 2 additions & 11 deletions py-polars/tests/unit/operations/test_ewm_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_ewma_by_date_constant() -> None:
assert_frame_equal(result, expected)


@pytest.mark.xfail()
def test_ewma_f32() -> None:
df = pl.LazyFrame(
{
Expand Down Expand Up @@ -171,20 +170,12 @@ def test_ewma_by_empty() -> None:
assert_frame_equal(result, expected)


@pytest.mark.xfail()
def test_ewma_by_warn_if_unsorted() -> None:
df = pl.DataFrame({"values": [2.0, 3.0], "by": [1, 3]})
with pytest.warns(
UserWarning, match="Series is not known to be sorted by `by` column"
):
result = df.select(
pl.col("values").ewm_mean_by("by", half_life="2i"),
)
expected = pl.DataFrame({"values": [2.0, 2.5]})
assert_frame_equal(result, expected)
result = df.select(
pl.col("values").ewm_mean_by("by", half_life="2i", warn_if_unsorted=False), # type: ignore[call-arg]
pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False),
)
expected = pl.DataFrame({"values": [2.0, 2.5]})
assert_frame_equal(result, expected)
result = df.sort("by").select(
pl.col("values").ewm_mean_by("by", half_life="2i"),
Expand Down

0 comments on commit 9c0844b

Please sign in to comment.