From 9de4ea4d696bf9e1d26dfcc8c2cf8ea85566677b Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 1 May 2024 14:35:19 +0200 Subject: [PATCH 1/2] fix: Ternary supertype dynamics --- crates/polars-core/src/series/mod.rs | 40 +++++++- .../optimizer/type_coercion/binary.rs | 1 - .../optimizer/type_coercion/mod.rs | 94 +------------------ .../tests/unit/functions/test_when_then.py | 19 ++++ 4 files changed, 57 insertions(+), 97 deletions(-) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index f6c7b64e471c..bc73d230f9de 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -29,7 +29,9 @@ pub use series_trait::{IsSorted, *}; use crate::chunked_array::Settings; #[cfg(feature = "zip_with")] use crate::series::arithmetic::coerce_lhs_rhs; -use crate::utils::{_split_offsets, handle_casting_failures, split_ca, split_series, Wrap}; +use crate::utils::{ + _split_offsets, handle_casting_failures, materialize_dyn_int, split_ca, split_series, Wrap, +}; use crate::POOL; /// # Series @@ -309,9 +311,39 @@ impl Series { /// Cast `[Series]` to another `[DataType]`. pub fn cast(&self, dtype: &DataType) -> PolarsResult { - // Best leave as is. - if !dtype.is_known() || (dtype.is_primitive() && dtype == self.dtype()) { - return Ok(self.clone()); + match dtype { + DataType::Unknown(kind) => { + return match kind { + // Best leave as is. + UnknownKind::Any => Ok(self.clone()), + UnknownKind::Int(v) => { + if self.dtype().is_integer() { + Ok(self.clone()) + } else { + self.cast(&materialize_dyn_int(*v).dtype()) + } + }, + UnknownKind::Float => { + if self.dtype().is_float() { + Ok(self.clone()) + } else { + self.cast(&DataType::Float64) + } + }, + UnknownKind::Str => { + if self.dtype().is_string() | self.dtype().is_categorical() { + Ok(self.clone()) + } else { + self.cast(&DataType::String) + } + }, + }; + }, + // Best leave as is. + dt if dt.is_primitive() && dt == self.dtype() => { + return Ok(self.clone()); + }, + _ => {}, } let ret = self.0.cast(dtype); let len = self.len(); diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs index 456f021d1e26..f0eb0051b803 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs @@ -240,7 +240,6 @@ pub(super) fn process_binary( right: node_right, })); }, - (Unknown(lhs), Unknown(rhs)) if lhs == rhs => return Ok(None), _ => { unpack!(early_escape(&type_left, &type_right)); }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index c97b89e52613..d38d58b027ef 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -31,54 +31,6 @@ fn modify_supertype( type_left: &DataType, type_right: &DataType, ) -> DataType { - use AExpr::*; - - let dynamic_st_or_unknown = matches!(st, DataType::Unknown(_)); - - match (left, right) { - ( - Literal( - lv_left @ (LiteralValue::Int(_) - | LiteralValue::Float(_) - | LiteralValue::StrCat(_) - | LiteralValue::Null), - ), - Literal( - lv_right @ (LiteralValue::Int(_) - | LiteralValue::Float(_) - | LiteralValue::StrCat(_) - | LiteralValue::Null), - ), - ) => { - let lhs = lv_left.to_any_value().unwrap().dtype(); - let rhs = lv_right.to_any_value().unwrap().dtype(); - st = get_supertype(&lhs, &rhs).unwrap(); - return st; - }, - // Materialize dynamic types - ( - Literal( - lv_left @ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)), - ), - _, - ) if dynamic_st_or_unknown => { - st = lv_left.to_any_value().unwrap().dtype(); - return st; - }, - ( - _, - Literal( - lv_right - @ (LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_)), - ), - ) if dynamic_st_or_unknown => { - st = lv_right.to_any_value().unwrap().dtype(); - return st; - }, - // do nothing - _ => {}, - } - // TODO! This must be removed and dealt properly with dynamic str. use DataType::*; match (type_left, type_right, left, right) { @@ -185,44 +137,9 @@ impl OptimizationRule for TypeCoercionRule { let (falsy, type_false) = unpack!(get_aexpr_and_type(expr_arena, falsy_node, &input_schema)); - match (&type_true, &type_false) { - (DataType::Unknown(lhs), DataType::Unknown(rhs)) => { - match (lhs, rhs) { - (UnknownKind::Any, _) | (_, UnknownKind::Any) => return Ok(None), - // continue - (UnknownKind::Int(_), UnknownKind::Float) - | (UnknownKind::Float, UnknownKind::Int(_)) => {}, - (lhs, rhs) if lhs == rhs => { - let falsy = materialize(falsy); - let truthy = materialize(truthy); - - if falsy.is_none() && truthy.is_none() { - return Ok(None); - } - - let falsy = if let Some(falsy) = falsy { - expr_arena.add(falsy) - } else { - falsy_node - }; - let truthy = if let Some(truthy) = truthy { - expr_arena.add(truthy) - } else { - truthy_node - }; - return Ok(Some(AExpr::Ternary { - truthy, - falsy, - predicate, - })); - }, - _ => {}, - } - }, - (lhs, rhs) if lhs == rhs => return Ok(None), - _ => {}, + if type_true == type_false { + return Ok(None); } - let st = unpack!(get_supertype(&type_true, &type_false)); let st = modify_supertype(st, truthy, falsy, &type_true, &type_false); @@ -612,13 +529,6 @@ fn inline_or_prune_cast( fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { match (type_self, type_other) { - (DataType::Unknown(lhs), DataType::Unknown(rhs)) => match (lhs, rhs) { - (UnknownKind::Any, _) | (_, UnknownKind::Any) => None, - (UnknownKind::Int(_), UnknownKind::Float) - | (UnknownKind::Float, UnknownKind::Int(_)) => Some(()), - (lhs, rhs) if lhs == rhs => None, - _ => Some(()), - }, (lhs, rhs) if lhs == rhs => None, _ => Some(()), } diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index 7625fece9987..8315b0801597 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -604,3 +604,22 @@ def test_when_then_supertype_15975() -> None: assert df.with_columns( pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a")) ).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]} + + +def test_when_then_supertype_15975_comment() -> None: + df = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]}) + + q = df.with_columns( + pl.when(pl.col("foo") == 1) + .then(1) + .when(pl.col("foo") == 2) + .then(4) + .when(pl.col("foo") == 3) + .then(1.5) + .when(pl.col("foo") == 4) + .then(16) + .otherwise(0) + .alias("val") + ) + + assert q.collect()["val"].to_list() == [1.0, 1.5, 16.0] From d9ac8755b022c5fc9996d77f7b3f0a3ca58973f3 Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 1 May 2024 14:51:00 +0200 Subject: [PATCH 2/2] fix test --- .../tests/unit/lazyframe/test_tree_format.py | 54 ------------------- 1 file changed, 54 deletions(-) delete mode 100644 py-polars/tests/unit/lazyframe/test_tree_format.py diff --git a/py-polars/tests/unit/lazyframe/test_tree_format.py b/py-polars/tests/unit/lazyframe/test_tree_format.py deleted file mode 100644 index 7ceb31fa5acc..000000000000 --- a/py-polars/tests/unit/lazyframe/test_tree_format.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -import polars as pl - - -def test_logical_plan_tree_format() -> None: - lf = ( - pl.LazyFrame( - { - "foo": [1, 2, 3], - "bar": [6, 7, 8], - "ham": ["a", "b", "c"], - } - ) - .select(foo=pl.col("foo") + 1, bar=pl.col("bar") + 2) - .select( - threshold=pl.when(pl.col("foo") + pl.col("bar") > 2).then(10).otherwise(0) - ) - ) - - expected = """ - SELECT [.when([([(col("foo")) + (col("bar"))]) > (2)]).then(10).otherwise(0).alias("threshold")] FROM - SELECT [[(col("foo")) + (1)].alias("foo"), [(col("bar")) + (2)].alias("bar")] FROM - DF ["foo", "bar", "ham"]; PROJECT 2/3 COLUMNS; SELECTION: "None" -""" - assert lf.explain().strip() == expected.strip() - - expected = """ - 0 1 2 3 - ┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── - │ - │ ╭────────╮ - 0 │ │ SELECT │ - │ ╰───┬┬───╯ - │ ││ - │ │╰─────────────────────────────────────╮ - │ │ │ - │ ╭───────────────────────┴────────────────────────╮ │ - │ │ expression: │ ╭───┴────╮ - │ │ .when([([(col("foo")) + (col("bar"))]) > (2)]) │ │ FROM: │ - 1 │ │ .then(10) │ │ SELECT │ - │ │ .otherwise(0) │ ╰───┬┬───╯ - │ │ .alias("threshold") │ ││ - │ ╰────────────────────────────────────────────────╯ ││ - │ ││ - │ │╰────────────────────────┬───────────────────────────╮ - │ │ │ │ - │ ╭──────────┴───────────╮ ╭──────────┴───────────╮ ╭────────────┴─────────────╮ - │ │ expression: │ │ expression: │ │ FROM: │ - 2 │ │ [(col("foo")) + (1)] │ │ [(col("bar")) + (2)] │ │ DF ["foo", "bar", "ham"] │ - │ │ .alias("foo") │ │ .alias("bar") │ │ PROJECT 2/3 COLUMNS │ - │ ╰──────────────────────╯ ╰──────────────────────╯ ╰──────────────────────────╯ -""" - assert lf.explain(tree_format=True).strip() == expected.strip()