From 7bab3146e5e0fc983ee142caf8fa74b42f764fae Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Fri, 27 Sep 2024 10:41:51 +0200 Subject: [PATCH 1/2] refactor: Divide `ChunkCompare` into `Eq` and `Ineq` variants Divide the `ChunkCompare` trait into two traits `ChunkCompareEq` and `ChunkCompareIneq`, which allows us to statistically verify that there are no calls to the inequality methods when these are not available (e.g. for `List`, `Array` and `Struct`). This makes error handling a lot better as well. For example, the following was a panic exception before. ```python import polars as pl a = pl.Series('a', [[1]], pl.Array(pl.Int8, 1)) b = pl.Series('b', [[1]], pl.Array(pl.Int8, 1)) c = a < b ``` Now, it returns: ``` polars.exceptions.InvalidOperationError: cannot perform '<' comparison between series 'a' of dtype: array[i8, 1] and series 'b' of dtype: array[i8, 1] ``` Fixes #18938. --- .../chunked_array/comparison/categorical.rs | 18 +- .../src/chunked_array/comparison/mod.rs | 75 +++---- .../src/chunked_array/comparison/scalar.rs | 24 ++- .../polars-core/src/chunked_array/ops/mod.rs | 29 +-- .../src/chunked_array/ops/unique/mod.rs | 3 +- crates/polars-core/src/frame/column/mod.rs | 22 +- crates/polars-core/src/series/comparison.rs | 203 +++++++++++++----- crates/polars-core/src/series/mod.rs | 2 +- crates/polars-expr/src/expressions/apply.rs | 16 +- crates/polars-expr/src/expressions/binary.rs | 20 +- .../src/chunked_array/array/count.rs | 2 +- .../src/chunked_array/list/count.rs | 2 +- crates/polars-ops/src/chunked_array/peaks.rs | 4 +- 13 files changed, 262 insertions(+), 158 deletions(-) diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index faa7f619cdb2..bbcd6b6047c9 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -96,7 +96,7 @@ where } } -impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { +impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &CategoricalChunked) -> Self::Item { @@ -134,6 +134,10 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { UInt32Chunked::not_equal_missing, ) } +} + +impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &CategoricalChunked) -> Self::Item { cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r) @@ -217,7 +221,7 @@ where } } -impl ChunkCompare<&StringChunked> for CategoricalChunked { +impl ChunkCompareEq<&StringChunked> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &StringChunked) -> Self::Item { @@ -265,6 +269,10 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { StringChunked::not_equal_missing, ) } +} + +impl ChunkCompareIneq<&StringChunked> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &StringChunked) -> Self::Item { cat_str_compare_helper( @@ -376,7 +384,7 @@ where } } -impl ChunkCompare<&str> for CategoricalChunked { +impl ChunkCompareEq<&str> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &str) -> Self::Item { @@ -414,6 +422,10 @@ impl ChunkCompare<&str> for CategoricalChunked { UInt32Chunked::equal_missing, ) } +} + +impl ChunkCompareIneq<&str> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &str) -> Self::Item { cat_single_str_compare_helper( diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 300f5f338cff..ecf8f78fcd9a 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -16,7 +16,7 @@ use crate::series::implementations::null::NullChunked; use crate::series::IsSorted; use crate::utils::align_chunks_binary; -impl ChunkCompare<&ChunkedArray> for ChunkedArray +impl ChunkCompareEq<&ChunkedArray> for ChunkedArray where T: PolarsNumericType, T::Array: TotalOrdKernel + TotalEqKernel, @@ -126,6 +126,14 @@ where ), } } +} + +impl ChunkCompareIneq<&ChunkedArray> for ChunkedArray +where + T: PolarsNumericType, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; fn lt(&self, rhs: &ChunkedArray) -> BooleanChunked { // Broadcast. @@ -188,7 +196,7 @@ where } } -impl ChunkCompare<&NullChunked> for NullChunked { +impl ChunkCompareEq<&NullChunked> for NullChunked { type Item = BooleanChunked; fn equal(&self, rhs: &NullChunked) -> Self::Item { @@ -206,6 +214,10 @@ impl ChunkCompare<&NullChunked> for NullChunked { fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item { BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs)) } +} + +impl ChunkCompareIneq<&NullChunked> for NullChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &NullChunked) -> Self::Item { BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) @@ -234,7 +246,7 @@ fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize { } } -impl ChunkCompare<&BooleanChunked> for BooleanChunked { +impl ChunkCompareEq<&BooleanChunked> for BooleanChunked { type Item = BooleanChunked; fn equal(&self, rhs: &BooleanChunked) -> BooleanChunked { @@ -348,6 +360,10 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { ), } } +} + +impl ChunkCompareIneq<&BooleanChunked> for BooleanChunked { + type Item = BooleanChunked; fn lt(&self, rhs: &BooleanChunked) -> BooleanChunked { // Broadcast. @@ -410,7 +426,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { } } -impl ChunkCompare<&StringChunked> for StringChunked { +impl ChunkCompareEq<&StringChunked> for StringChunked { type Item = BooleanChunked; fn equal(&self, rhs: &StringChunked) -> BooleanChunked { @@ -424,9 +440,14 @@ impl ChunkCompare<&StringChunked> for StringChunked { fn not_equal(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().not_equal(&rhs.as_binary()) } + fn not_equal_missing(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().not_equal_missing(&rhs.as_binary()) } +} + +impl ChunkCompareIneq<&StringChunked> for StringChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().gt(&rhs.as_binary()) @@ -445,7 +466,7 @@ impl ChunkCompare<&StringChunked> for StringChunked { } } -impl ChunkCompare<&BinaryChunked> for BinaryChunked { +impl ChunkCompareEq<&BinaryChunked> for BinaryChunked { type Item = BooleanChunked; fn equal(&self, rhs: &BinaryChunked) -> BooleanChunked { @@ -551,6 +572,10 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { ), } } +} + +impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked { + type Item = BooleanChunked; fn lt(&self, rhs: &BinaryChunked) -> BooleanChunked { // Broadcast. @@ -644,7 +669,7 @@ where } } -impl ChunkCompare<&ListChunked> for ListChunked { +impl ChunkCompareEq<&ListChunked> for ListChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ListChunked) -> BooleanChunked { let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { @@ -684,23 +709,6 @@ impl ChunkCompare<&ListChunked> for ListChunked { _list_comparison_helper(self, rhs, _series_not_equal_missing) } - - // The following are not implemented because gt, lt comparison of series don't make sense. - fn gt(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn gt_eq(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt_eq(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } } #[cfg(feature = "dtype-struct")] @@ -741,7 +749,7 @@ where } #[cfg(feature = "dtype-struct")] -impl ChunkCompare<&StructChunked> for StructChunked { +impl ChunkCompareEq<&StructChunked> for StructChunked { type Item = BooleanChunked; fn equal(&self, rhs: &StructChunked) -> BooleanChunked { struct_helper( @@ -785,7 +793,7 @@ impl ChunkCompare<&StructChunked> for StructChunked { } #[cfg(feature = "dtype-array")] -impl ChunkCompare<&ArrayChunked> for ArrayChunked { +impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { if self.width() != rhs.width() { @@ -834,23 +842,6 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked { PlSmallStr::EMPTY, ) } - - // following are not implemented because gt, lt comparison of series don't make sense - fn gt(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn gt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } } impl Not for &BooleanChunked { diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index 1c632299c1e4..4649696b41dd 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -61,13 +61,14 @@ where ca } -impl ChunkCompare for ChunkedArray +impl ChunkCompareEq for ChunkedArray where T: PolarsNumericType, Rhs: ToPrimitive, T::Array: TotalOrdKernel + TotalEqKernel, { type Item = BooleanChunked; + fn equal(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); let fa = Some(|x: T::Native| x.tot_ge(&rhs)); @@ -111,6 +112,15 @@ where }) } } +} + +impl ChunkCompareIneq for ChunkedArray +where + T: PolarsNumericType, + Rhs: ToPrimitive, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; fn gt(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); @@ -157,7 +167,7 @@ where } } -impl ChunkCompare<&[u8]> for BinaryChunked { +impl ChunkCompareEq<&[u8]> for BinaryChunked { type Item = BooleanChunked; fn equal(&self, rhs: &[u8]) -> BooleanChunked { @@ -175,6 +185,10 @@ impl ChunkCompare<&[u8]> for BinaryChunked { fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked { arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) } +} + +impl ChunkCompareIneq<&[u8]> for BinaryChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &[u8]) -> BooleanChunked { arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) @@ -193,7 +207,7 @@ impl ChunkCompare<&[u8]> for BinaryChunked { } } -impl ChunkCompare<&str> for StringChunked { +impl ChunkCompareEq<&str> for StringChunked { type Item = BooleanChunked; fn equal(&self, rhs: &str) -> BooleanChunked { @@ -211,6 +225,10 @@ impl ChunkCompare<&str> for StringChunked { fn not_equal_missing(&self, rhs: &str) -> BooleanChunked { arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) } +} + +impl ChunkCompareIneq<&str> for StringChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &str) -> BooleanChunked { arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 8da567d06491..2bc1337e598f 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -38,7 +38,6 @@ pub(crate) mod unique; #[cfg(feature = "zip_with")] pub mod zip; -use polars_utils::no_call_const; #[cfg(feature = "serde-lazy")] use serde::{Deserialize, Serialize}; pub use sort::options::*; @@ -312,7 +311,7 @@ pub trait ChunkVar { /// df.filter(&mask) /// } /// ``` -pub trait ChunkCompare { +pub trait ChunkCompareEq { type Item; /// Check for equality. @@ -326,30 +325,24 @@ pub trait ChunkCompare { /// Check for inequality where `None == None`. fn not_equal_missing(&self, rhs: Rhs) -> Self::Item; +} + +/// Compare [`Series`] and [`ChunkedArray`]'s using inequality operators (`<`, `>=`, etc.) and get +/// a `boolean` mask that can be used to filter rows. +pub trait ChunkCompareIneq { + type Item; /// Greater than comparison. - #[allow(unused_variables)] - fn gt(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn gt(&self, rhs: Rhs) -> Self::Item; /// Greater than or equal comparison. - #[allow(unused_variables)] - fn gt_eq(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn gt_eq(&self, rhs: Rhs) -> Self::Item; /// Less than comparison. - #[allow(unused_variables)] - fn lt(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn lt(&self, rhs: Rhs) -> Self::Item; /// Less than or equal comparison - #[allow(unused_variables)] - fn lt_eq(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn lt_eq(&self, rhs: Rhs) -> Self::Item; } /// Get unique values in a `ChunkedArray` diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index b645088b4d68..b073700867af 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -87,7 +87,8 @@ where T: PolarsNumericType, T::Native: TotalHash + TotalEq + ToTotalOrd, ::TotalOrdItem: Hash + Eq + Ord, - ChunkedArray: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: + IntoSeries + for<'a> ChunkCompareEq<&'a ChunkedArray, Item = BooleanChunked>, { fn unique(&self) -> PolarsResult { // prevent stackoverflow repeated sorted.unique call diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 78c36db57f78..cb88a9946ca4 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -992,61 +992,65 @@ impl Column { } } -impl ChunkCompare<&Column> for Column { +impl ChunkCompareEq<&Column> for Column { type Item = PolarsResult; /// Create a boolean mask by checking for equality. #[inline] - fn equal(&self, rhs: &Column) -> PolarsResult { + fn equal(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .equal(rhs.as_materialized_series()) } /// Create a boolean mask by checking for equality. #[inline] - fn equal_missing(&self, rhs: &Column) -> PolarsResult { + fn equal_missing(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .equal_missing(rhs.as_materialized_series()) } /// Create a boolean mask by checking for inequality. #[inline] - fn not_equal(&self, rhs: &Column) -> PolarsResult { + fn not_equal(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .not_equal(rhs.as_materialized_series()) } /// Create a boolean mask by checking for inequality. #[inline] - fn not_equal_missing(&self, rhs: &Column) -> PolarsResult { + fn not_equal_missing(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .not_equal_missing(rhs.as_materialized_series()) } +} + +impl ChunkCompareIneq<&Column> for Column { + type Item = PolarsResult; /// Create a boolean mask by checking if self > rhs. #[inline] - fn gt(&self, rhs: &Column) -> PolarsResult { + fn gt(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .gt(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self >= rhs. #[inline] - fn gt_eq(&self, rhs: &Column) -> PolarsResult { + fn gt_eq(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .gt_eq(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self < rhs. #[inline] - fn lt(&self, rhs: &Column) -> PolarsResult { + fn lt(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .lt(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self <= rhs. #[inline] - fn lt_eq(&self, rhs: &Column) -> PolarsResult { + fn lt_eq(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .lt_eq(rhs.as_materialized_series()) } diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index bea981db89f1..228221c076aa 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -4,8 +4,8 @@ use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; use crate::series::nulls::replace_non_null; -macro_rules! impl_compare { - ($self:expr, $rhs:expr, $method:ident, $struct_function:expr) => {{ +macro_rules! impl_eq_compare { + ($self:expr, $rhs:expr, $method:ident) => {{ use DataType::*; let (lhs, rhs) = ($self, $rhs); validate_types(lhs.dtype(), rhs.dtype())?; @@ -70,14 +70,7 @@ macro_rules! impl_compare { #[cfg(feature = "dtype-array")] Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()), #[cfg(feature = "dtype-struct")] - Struct(_) => { - let lhs = lhs - .struct_() - .unwrap(); - let rhs = rhs.struct_().unwrap(); - - $struct_function(lhs, rhs)? - }, + Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()), #[cfg(feature = "dtype-decimal")] Decimal(_, s1) => { let DataType::Decimal(_, s2) = rhs.dtype() else { @@ -96,14 +89,108 @@ macro_rules! impl_compare { }}; } -#[cfg(feature = "dtype-struct")] -fn raise_struct(_a: &StructChunked, _b: &StructChunked) -> PolarsResult { - polars_bail!(InvalidOperation: "order comparison not support for struct dtype") +macro_rules! bail_invalid_ineq { + ($lhs:expr, $rhs:expr, $op:literal) => { + polars_bail!( + InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + $lhs.name(), $lhs.dtype(), + $rhs.name(), $rhs.dtype(), + ) + }; } -#[cfg(not(feature = "dtype-struct"))] -fn raise_struct(_a: &(), _b: &()) -> PolarsResult { - unimplemented!() +macro_rules! impl_ineq_compare { + ($self:expr, $rhs:expr, $method:ident, $op:literal) => {{ + use DataType::*; + let (lhs, rhs) = ($self, $rhs); + validate_types(lhs.dtype(), rhs.dtype())?; + + polars_ensure!( + lhs.len() == rhs.len() || + + // Broadcast + lhs.len() == 1 || + rhs.len() == 1, + ShapeMismatch: + "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths", + $op, + lhs.name(), lhs.len(), + rhs.name(), rhs.len() + ); + + #[cfg(feature = "dtype-categorical")] + match (lhs.dtype(), rhs.dtype()) { + (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.categorical().unwrap())? + .with_name(lhs.name().clone())); + }, + (Categorical(_, _) | Enum(_, _), String) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + (String, Categorical(_, _) | Enum(_, _)) => { + return Ok(rhs + .categorical() + .unwrap() + .$method(lhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + _ => (), + }; + + let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_| + polars_err!( + SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + lhs.name(), lhs.dtype(), + rhs.name(), rhs.dtype() + ) + )?; + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + let mut out = match lhs.dtype() { + Null => lhs.null().unwrap().$method(rhs.null().unwrap()), + Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()), + String => lhs.str().unwrap().$method(rhs.str().unwrap()), + Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()), + UInt8 => lhs.u8().unwrap().$method(rhs.u8().unwrap()), + UInt16 => lhs.u16().unwrap().$method(rhs.u16().unwrap()), + UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()), + UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()), + Int8 => lhs.i8().unwrap().$method(rhs.i8().unwrap()), + Int16 => lhs.i16().unwrap().$method(rhs.i16().unwrap()), + Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()), + Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()), + Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()), + Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()), + List(_) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-array")] + Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-struct")] + Struct(_) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-decimal")] + Decimal(_, s1) => { + let DataType::Decimal(_, s2) = rhs.dtype() else { + unreachable!() + }; + let scale = s1.max(s2).unwrap(); + let lhs = lhs.decimal().unwrap().to_scale(scale).unwrap(); + let rhs = rhs.decimal().unwrap().to_scale(scale).unwrap(); + lhs.0.$method(&rhs.0) + }, + + dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()), + }; + out.rename(lhs.name().clone()); + PolarsResult::Ok(out) + }}; } fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> { @@ -124,74 +211,61 @@ fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> { Ok(()) } -impl ChunkCompare<&Series> for Series { +impl ChunkCompareEq<&Series> for Series { type Item = PolarsResult; /// Create a boolean mask by checking for equality. - fn equal(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, equal, |a: &StructChunked, b: &StructChunked| { - PolarsResult::Ok(a.equal(b)) - }) + fn equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal) } /// Create a boolean mask by checking for equality. - fn equal_missing(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - equal_missing, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.equal_missing(b)) - ) + fn equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal_missing) } /// Create a boolean mask by checking for inequality. - fn not_equal(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - not_equal, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.not_equal(b)) - ) + fn not_equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal) } /// Create a boolean mask by checking for inequality. - fn not_equal_missing(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - not_equal_missing, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.not_equal_missing(b)) - ) + fn not_equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal_missing) } +} + +impl ChunkCompareIneq<&Series> for Series { + type Item = PolarsResult; /// Create a boolean mask by checking if self > rhs. - fn gt(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, gt, raise_struct) + fn gt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt, ">") } /// Create a boolean mask by checking if self >= rhs. - fn gt_eq(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, gt_eq, raise_struct) + fn gt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt_eq, ">=") } /// Create a boolean mask by checking if self < rhs. - fn lt(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, lt, raise_struct) + fn lt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt, "<") } /// Create a boolean mask by checking if self <= rhs. - fn lt_eq(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, lt_eq, raise_struct) + fn lt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt_eq, "<=") } } -impl ChunkCompare for Series +impl ChunkCompareEq for Series where Rhs: NumericNative, { type Item = PolarsResult; - fn equal(&self, rhs: Rhs) -> PolarsResult { + fn equal(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, equal, rhs)) @@ -203,7 +277,7 @@ where Ok(apply_method_physical_numeric!(&s, equal_missing, rhs)) } - fn not_equal(&self, rhs: Rhs) -> PolarsResult { + fn not_equal(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, not_equal, rhs)) @@ -214,33 +288,40 @@ where let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs)) } +} - fn gt(&self, rhs: Rhs) -> PolarsResult { +impl ChunkCompareIneq for Series +where + Rhs: NumericNative, +{ + type Item = PolarsResult; + + fn gt(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, gt, rhs)) } - fn gt_eq(&self, rhs: Rhs) -> PolarsResult { + fn gt_eq(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, gt_eq, rhs)) } - fn lt(&self, rhs: Rhs) -> PolarsResult { + fn lt(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, lt, rhs)) } - fn lt_eq(&self, rhs: Rhs) -> PolarsResult { + fn lt_eq(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, lt_eq, rhs)) } } -impl ChunkCompare<&str> for Series { +impl ChunkCompareEq<&str> for Series { type Item = PolarsResult; fn equal(&self, rhs: &str) -> PolarsResult { @@ -294,8 +375,12 @@ impl ChunkCompare<&str> for Series { _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)), } } +} + +impl ChunkCompareIneq<&str> for Series { + type Item = PolarsResult; - fn gt(&self, rhs: &str) -> PolarsResult { + fn gt(&self, rhs: &str) -> Self::Item { validate_types(self.dtype(), &DataType::String)?; match self.dtype() { DataType::String => Ok(self.str().unwrap().gt(rhs)), diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index ce9bcffba2f0..a4b40907ad88 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -1,5 +1,5 @@ //! Type agnostic columnar data structure. -pub use crate::prelude::ChunkCompare; +pub use crate::prelude::ChunkCompareEq; use crate::prelude::*; use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 0eeb8555071b..803e0801e636 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -614,12 +614,12 @@ impl ApplyExpr { if max.get(0).unwrap() == min.get(0).unwrap() { let one_equals = - |value: &Series| Some(ChunkCompare::equal(input, value).ok()?.any()); + |value: &Series| Some(ChunkCompareEq::equal(input, value).ok()?.any()); return one_equals(min); } - let smaller = ChunkCompare::lt(input, min).ok()?; - let bigger = ChunkCompare::gt(input, max).ok()?; + let smaller = ChunkCompareIneq::lt(input, min).ok()?; + let bigger = ChunkCompareIneq::gt(input, max).ok()?; Some(!(smaller | bigger).all()) }; @@ -662,7 +662,7 @@ impl ApplyExpr { // don't read the row_group anyways as // the condition will evaluate to false. // e.g. in_between(10, 5) - if ChunkCompare::gt(&left, &right).ok()?.all() { + if ChunkCompareIneq::gt(&left, &right).ok()?.all() { return Some(false); } @@ -674,15 +674,15 @@ impl ApplyExpr { }; // check the right limit of the interval. // if the end is open, we should be stricter (lt_eq instead of lt). - if right_open && ChunkCompare::lt_eq(&right, min).ok()?.all() - || !right_open && ChunkCompare::lt(&right, min).ok()?.all() + if right_open && ChunkCompareIneq::lt_eq(&right, min).ok()?.all() + || !right_open && ChunkCompareIneq::lt(&right, min).ok()?.all() { return Some(false); } // we couldn't conclude anything using the right limit, // check the left limit of the interval - if left_open && ChunkCompare::gt_eq(&left, max).ok()?.all() - || !left_open && ChunkCompare::gt(&left, max).ok()?.all() + if left_open && ChunkCompareIneq::gt_eq(&left, max).ok()?.all() + || !left_open && ChunkCompareIneq::gt(&left, max).ok()?.all() { return Some(false); } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 179f4524aa7e..d0b00bf2ddac 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -55,12 +55,12 @@ fn apply_operator_owned(left: Series, right: Series, op: Operator) -> PolarsResu pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResult { use DataType::*; match op { - Operator::Gt => ChunkCompare::gt(left, right).map(|ca| ca.into_series()), - Operator::GtEq => ChunkCompare::gt_eq(left, right).map(|ca| ca.into_series()), - Operator::Lt => ChunkCompare::lt(left, right).map(|ca| ca.into_series()), - Operator::LtEq => ChunkCompare::lt_eq(left, right).map(|ca| ca.into_series()), - Operator::Eq => ChunkCompare::equal(left, right).map(|ca| ca.into_series()), - Operator::NotEq => ChunkCompare::not_equal(left, right).map(|ca| ca.into_series()), + Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_series()), + Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_series()), + Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_series()), + Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_series()), + Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_series()), + Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_series()), Operator::Plus => left + right, Operator::Minus => left - right, Operator::Multiply => left * right, @@ -283,7 +283,7 @@ mod stats { use super::*; fn apply_operator_stats_eq(min_max: &Series, literal: &Series) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; // Literal is greater than max, don't need to read. if C::gt(literal, min_max).map(|s| s.all()).unwrap_or(false) { return false; @@ -301,7 +301,7 @@ mod stats { if min_max.len() < 2 || min_max.null_count() > 0 { return true; } - use ChunkCompare as C; + use ChunkCompareEq as C; // First check proofs all values are the same (e.g. min/max is the same) // Second check proofs all values are equal, so we can skip as we search @@ -315,7 +315,7 @@ mod stats { } fn apply_operator_stats_rhs_lit(min_max: &Series, literal: &Series, op: Operator) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), Operator::NotEq => apply_operator_stats_neq(min_max, literal), @@ -351,7 +351,7 @@ mod stats { } fn apply_operator_stats_lhs_lit(literal: &Series, min_max: &Series, op: Operator) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), Operator::NotEq => apply_operator_stats_eq(min_max, literal), diff --git a/crates/polars-ops/src/chunked_array/array/count.rs b/crates/polars-ops/src/chunked_array/array/count.rs index ef54e7b70591..466f148463bf 100644 --- a/crates/polars-ops/src/chunked_array/array/count.rs +++ b/crates/polars-ops/src/chunked_array/array/count.rs @@ -11,7 +11,7 @@ pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult::equal_missing(&s, &value).map(|ca| ca.into_series()) + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) })?; let out = count_boolean_bits(&ca); Ok(out.into_series()) diff --git a/crates/polars-ops/src/chunked_array/list/count.rs b/crates/polars-ops/src/chunked_array/list/count.rs index e54c603f3a25..89fdd71ed5d2 100644 --- a/crates/polars-ops/src/chunked_array/list/count.rs +++ b/crates/polars-ops/src/chunked_array/list/count.rs @@ -45,7 +45,7 @@ pub fn list_count_matches(ca: &ListChunked, value: AnyValue) -> PolarsResult::equal_missing(&s, &value).map(|ca| ca.into_series()) + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) })?; let out = count_boolean_bits(&ca); Ok(out.into_series()) diff --git a/crates/polars-ops/src/chunked_array/peaks.rs b/crates/polars-ops/src/chunked_array/peaks.rs index 437756a44327..7631a07ac141 100644 --- a/crates/polars-ops/src/chunked_array/peaks.rs +++ b/crates/polars-ops/src/chunked_array/peaks.rs @@ -4,7 +4,7 @@ use polars_core::prelude::*; /// Get a boolean mask of the local maximum peaks. pub fn peak_max(ca: &ChunkedArray) -> BooleanChunked where - ChunkedArray: for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, { let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); @@ -14,7 +14,7 @@ where /// Get a boolean mask of the local minimum peaks. pub fn peak_min(ca: &ChunkedArray) -> BooleanChunked where - ChunkedArray: for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, { let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); From 592a874c2062ccf2df0ae1465a2019d80f375c33 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Fri, 27 Sep 2024 12:41:56 +0200 Subject: [PATCH 2/2] fix link --- crates/polars-core/src/series/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index a4b40907ad88..72cb3b67dc41 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -90,7 +90,8 @@ use crate::POOL; /// .all(|(a, b)| a == *b)) /// ``` /// -/// See all the comparison operators in the [CmpOps trait](crate::chunked_array::ops::ChunkCompare) +/// See all the comparison operators in the [ChunkCompareEq trait](crate::chunked_array::ops::ChunkCompareEq) and +/// [ChunkCompareIneq trait](crate::chunked_array::ops::ChunkCompareIneq). /// /// ## Iterators /// The Series variants contain differently typed [ChunkedArray](crate::chunked_array::ChunkedArray)s.