Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Divide ChunkCompare into Eq and Ineq variants #18963

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions crates/polars-core/src/chunked_array/comparison/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ where
}
}

impl ChunkCompare<&CategoricalChunked> for CategoricalChunked {
impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &CategoricalChunked) -> Self::Item {
Expand Down Expand Up @@ -134,6 +134,10 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked {
UInt32Chunked::not_equal_missing,
)
}
}

impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r)
Expand Down Expand Up @@ -217,7 +221,7 @@ where
}
}

impl ChunkCompare<&StringChunked> for CategoricalChunked {
impl ChunkCompareEq<&StringChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &StringChunked) -> Self::Item {
Expand Down Expand Up @@ -265,6 +269,10 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
StringChunked::not_equal_missing,
)
}
}

impl ChunkCompareIneq<&StringChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
Expand Down Expand Up @@ -376,7 +384,7 @@ where
}
}

impl ChunkCompare<&str> for CategoricalChunked {
impl ChunkCompareEq<&str> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &str) -> Self::Item {
Expand Down Expand Up @@ -414,6 +422,10 @@ impl ChunkCompare<&str> for CategoricalChunked {
UInt32Chunked::equal_missing,
)
}
}

impl ChunkCompareIneq<&str> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
Expand Down
75 changes: 33 additions & 42 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::series::implementations::null::NullChunked;
use crate::series::IsSorted;
use crate::utils::align_chunks_binary;

impl<T> ChunkCompare<&ChunkedArray<T>> for ChunkedArray<T>
impl<T> ChunkCompareEq<&ChunkedArray<T>> for ChunkedArray<T>
where
T: PolarsNumericType,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
Expand Down Expand Up @@ -126,6 +126,14 @@ where
),
}
}
}

impl<T> ChunkCompareIneq<&ChunkedArray<T>> for ChunkedArray<T>
where
T: PolarsNumericType,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn lt(&self, rhs: &ChunkedArray<T>) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -188,7 +196,7 @@ where
}
}

impl ChunkCompare<&NullChunked> for NullChunked {
impl ChunkCompareEq<&NullChunked> for NullChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &NullChunked) -> Self::Item {
Expand All @@ -206,6 +214,10 @@ impl ChunkCompare<&NullChunked> for NullChunked {
fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item {
BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs))
}
}

impl ChunkCompareIneq<&NullChunked> for NullChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &NullChunked) -> Self::Item {
BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs))
Expand Down Expand Up @@ -234,7 +246,7 @@ fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize {
}
}

impl ChunkCompare<&BooleanChunked> for BooleanChunked {
impl ChunkCompareEq<&BooleanChunked> for BooleanChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &BooleanChunked) -> BooleanChunked {
Expand Down Expand Up @@ -348,6 +360,10 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked {
),
}
}
}

impl ChunkCompareIneq<&BooleanChunked> for BooleanChunked {
type Item = BooleanChunked;

fn lt(&self, rhs: &BooleanChunked) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -410,7 +426,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked {
}
}

impl ChunkCompare<&StringChunked> for StringChunked {
impl ChunkCompareEq<&StringChunked> for StringChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &StringChunked) -> BooleanChunked {
Expand All @@ -424,9 +440,14 @@ impl ChunkCompare<&StringChunked> for StringChunked {
fn not_equal(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().not_equal(&rhs.as_binary())
}

fn not_equal_missing(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().not_equal_missing(&rhs.as_binary())
}
}

impl ChunkCompareIneq<&StringChunked> for StringChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().gt(&rhs.as_binary())
Expand All @@ -445,7 +466,7 @@ impl ChunkCompare<&StringChunked> for StringChunked {
}
}

impl ChunkCompare<&BinaryChunked> for BinaryChunked {
impl ChunkCompareEq<&BinaryChunked> for BinaryChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &BinaryChunked) -> BooleanChunked {
Expand Down Expand Up @@ -551,6 +572,10 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked {
),
}
}
}

impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked {
type Item = BooleanChunked;

fn lt(&self, rhs: &BinaryChunked) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -644,7 +669,7 @@ where
}
}

impl ChunkCompare<&ListChunked> for ListChunked {
impl ChunkCompareEq<&ListChunked> for ListChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) {
Expand Down Expand Up @@ -684,23 +709,6 @@ impl ChunkCompare<&ListChunked> for ListChunked {

_list_comparison_helper(self, rhs, _series_not_equal_missing)
}

// The following are not implemented because gt, lt comparison of series don't make sense.
fn gt(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn gt_eq(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn lt(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn lt_eq(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}
}

#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -741,7 +749,7 @@ where
}

#[cfg(feature = "dtype-struct")]
impl ChunkCompare<&StructChunked> for StructChunked {
impl ChunkCompareEq<&StructChunked> for StructChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &StructChunked) -> BooleanChunked {
struct_helper(
Expand Down Expand Up @@ -785,7 +793,7 @@ impl ChunkCompare<&StructChunked> for StructChunked {
}

#[cfg(feature = "dtype-array")]
impl ChunkCompare<&ArrayChunked> for ArrayChunked {
impl ChunkCompareEq<&ArrayChunked> for ArrayChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked {
if self.width() != rhs.width() {
Expand Down Expand Up @@ -834,23 +842,6 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked {
PlSmallStr::EMPTY,
)
}

// following are not implemented because gt, lt comparison of series don't make sense
fn gt(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn gt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn lt(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn lt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}
}

impl Not for &BooleanChunked {
Expand Down
24 changes: 21 additions & 3 deletions crates/polars-core/src/chunked_array/comparison/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ where
ca
}

impl<T, Rhs> ChunkCompare<Rhs> for ChunkedArray<T>
impl<T, Rhs> ChunkCompareEq<Rhs> for ChunkedArray<T>
where
T: PolarsNumericType,
Rhs: ToPrimitive,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn equal(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_ge(&rhs));
Expand Down Expand Up @@ -111,6 +112,15 @@ where
})
}
}
}

impl<T, Rhs> ChunkCompareIneq<Rhs> for ChunkedArray<T>
where
T: PolarsNumericType,
Rhs: ToPrimitive,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn gt(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
Expand Down Expand Up @@ -157,7 +167,7 @@ where
}
}

impl ChunkCompare<&[u8]> for BinaryChunked {
impl ChunkCompareEq<&[u8]> for BinaryChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &[u8]) -> BooleanChunked {
Expand All @@ -175,6 +185,10 @@ impl ChunkCompare<&[u8]> for BinaryChunked {
fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
}
}

impl ChunkCompareIneq<&[u8]> for BinaryChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &[u8]) -> BooleanChunked {
arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
Expand All @@ -193,7 +207,7 @@ impl ChunkCompare<&[u8]> for BinaryChunked {
}
}

impl ChunkCompare<&str> for StringChunked {
impl ChunkCompareEq<&str> for StringChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &str) -> BooleanChunked {
Expand All @@ -211,6 +225,10 @@ impl ChunkCompare<&str> for StringChunked {
fn not_equal_missing(&self, rhs: &str) -> BooleanChunked {
arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
}
}

impl ChunkCompareIneq<&str> for StringChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &str) -> BooleanChunked {
arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
Expand Down
29 changes: 11 additions & 18 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ pub(crate) mod unique;
#[cfg(feature = "zip_with")]
pub mod zip;

use polars_utils::no_call_const;
#[cfg(feature = "serde-lazy")]
use serde::{Deserialize, Serialize};
pub use sort::options::*;
Expand Down Expand Up @@ -312,7 +311,7 @@ pub trait ChunkVar {
/// df.filter(&mask)
/// }
/// ```
pub trait ChunkCompare<Rhs> {
pub trait ChunkCompareEq<Rhs> {
type Item;

/// Check for equality.
Expand All @@ -326,30 +325,24 @@ pub trait ChunkCompare<Rhs> {

/// Check for inequality where `None == None`.
fn not_equal_missing(&self, rhs: Rhs) -> Self::Item;
}

/// Compare [`Series`] and [`ChunkedArray`]'s using inequality operators (`<`, `>=`, etc.) and get
/// a `boolean` mask that can be used to filter rows.
pub trait ChunkCompareIneq<Rhs> {
type Item;

/// Greater than comparison.
#[allow(unused_variables)]
fn gt(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn gt(&self, rhs: Rhs) -> Self::Item;

/// Greater than or equal comparison.
#[allow(unused_variables)]
fn gt_eq(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn gt_eq(&self, rhs: Rhs) -> Self::Item;

/// Less than comparison.
#[allow(unused_variables)]
fn lt(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn lt(&self, rhs: Rhs) -> Self::Item;

/// Less than or equal comparison
#[allow(unused_variables)]
fn lt_eq(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn lt_eq(&self, rhs: Rhs) -> Self::Item;
}

/// Get unique values in a `ChunkedArray`
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ where
T: PolarsNumericType,
T::Native: TotalHash + TotalEq + ToTotalOrd,
<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Ord,
ChunkedArray<T>: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray<T>, Item = BooleanChunked>,
ChunkedArray<T>:
IntoSeries + for<'a> ChunkCompareEq<&'a ChunkedArray<T>, Item = BooleanChunked>,
{
fn unique(&self) -> PolarsResult<Self> {
// prevent stackoverflow repeated sorted.unique call
Expand Down
Loading
Loading