From dab8254dd7fcf7ca3882e799a081c1a1cf45607a Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 4 Sep 2024 12:51:47 +0200 Subject: [PATCH] fix: Fix type inference error in `map_elements` for List types (#18542) --- crates/polars-python/src/map/dataframe.rs | 9 +- crates/polars-python/src/map/series.rs | 8 +- crates/polars-python/src/series/general.rs | 236 ----------------- crates/polars-python/src/series/map.rs | 245 ++++++++++++++++++ crates/polars-python/src/series/mod.rs | 2 + py-polars/polars/series/series.py | 4 +- .../unit/operations/map/test_map_elements.py | 27 +- 7 files changed, 275 insertions(+), 256 deletions(-) create mode 100644 crates/polars-python/src/series/map.rs diff --git a/crates/polars-python/src/map/dataframe.rs b/crates/polars-python/src/map/dataframe.rs index c91353dfff8d..5be2216b0898 100644 --- a/crates/polars-python/src/map/dataframe.rs +++ b/crates/polars-python/src/map/dataframe.rs @@ -8,15 +8,14 @@ use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple}; use super::*; use crate::PyDataFrame; +/// Create iterators for all the Series in the DataFrame. fn get_iters(df: &DataFrame) -> Vec { df.get_columns().iter().map(|s| s.iter()).collect() } -fn get_iters_skip(df: &DataFrame, skip: usize) -> Vec> { - df.get_columns() - .iter() - .map(|s| s.iter().skip(skip)) - .collect() +/// Create iterators for all the Series in the DataFrame, skipping the first `n` rows. +fn get_iters_skip(df: &DataFrame, n: usize) -> Vec> { + df.get_columns().iter().map(|s| s.iter().skip(n)).collect() } // the return type is Union[PySeries, PyDataFrame] and a boolean indicating if it is a dataframe or not diff --git a/crates/polars-python/src/map/series.rs b/crates/polars-python/src/map/series.rs index 3afebc16f046..7940be418aa8 100644 --- a/crates/polars-python/src/map/series.rs +++ b/crates/polars-python/src/map/series.rs @@ -46,16 +46,16 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( let py_pyseries = series.getattr(py, "_s").unwrap(); let series = py_pyseries.extract::(py).unwrap().series; - // Empty dtype is incorrect, use AnyValues. - if series.is_empty() { + let dt = series.dtype(); + + // Null dtype may be incorrect, fall back to AnyValues logic. + if dt.is_nested_null() { let av = out.extract::>()?; return applyer .apply_extract_any_values(py, lambda, null_count, av.0) .map(|s| s.into()); } - let dt = series.dtype(); - // make a new python function that is: // def new_lambda(lambda: Callable): // pl.Series(lambda(value)) diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index 9d2c0a1de98e..359f39df6291 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -11,10 +11,8 @@ use pyo3::Python; use super::PySeries; use crate::dataframe::PyDataFrame; use crate::error::PyPolarsErr; -use crate::map::series::{call_lambda_and_extract, ApplyLambda}; use crate::prelude::*; use crate::py_modules::POLARS; -use crate::{apply_method_all_arrow_series2, raise_err}; #[pymethods] impl PySeries { @@ -315,240 +313,6 @@ impl PySeries { self.series.clone().into() } - #[pyo3(signature = (lambda, output_type, skip_nulls))] - fn apply_lambda( - &self, - lambda: &Bound, - output_type: Option>, - skip_nulls: bool, - ) -> PyResult { - let series = &self.series; - - if output_type.is_none() { - polars_warn!( - MapWithoutReturnDtypeWarning, - "Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \ - Specify `return_dtype` to silence this warning.") - } - - if skip_nulls && (series.null_count() == series.len()) { - if let Some(output_type) = output_type { - return Ok( - Series::full_null(series.name().clone(), series.len(), &output_type.0).into(), - ); - } - let msg = "The output type of the 'map_elements' function cannot be determined.\n\ - The function was never called because 'skip_nulls=True' and all values are null.\n\ - Consider setting 'skip_nulls=False' or setting the 'return_dtype'."; - raise_err!(msg, ComputeError) - } - - let output_type = output_type.map(|dt| dt.0); - - macro_rules! dispatch_apply { - ($self:expr, $method:ident, $($args:expr),*) => { - match $self.dtype() { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let ca = $self.0.unpack::>().unwrap(); - ca.$method($($args),*) - }, - _ => { - apply_method_all_arrow_series2!( - $self, - $method, - $($args),* - ) - } - - } - } - - } - - Python::with_gil(|py| { - if matches!( - self.series.dtype(), - DataType::Datetime(_, _) - | DataType::Date - | DataType::Duration(_) - | DataType::Categorical(_, _) - | DataType::Enum(_, _) - | DataType::Binary - | DataType::Array(_, _) - | DataType::Time - ) || !skip_nulls - { - let mut avs = Vec::with_capacity(self.series.len()); - let s = self.series.rechunk(); - let iter = s.iter().map(|av| match (skip_nulls, av) { - (true, AnyValue::Null) => AnyValue::Null, - (_, av) => { - let input = Wrap(av); - call_lambda_and_extract::<_, Wrap>(py, lambda, input) - .unwrap() - .0 - }, - }); - avs.extend(iter); - return Ok(Series::new(self.series.name().clone(), &avs).into()); - } - - let out = match output_type { - Some(DataType::Int8) => { - let ca: Int8Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int16) => { - let ca: Int16Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int32) => { - let ca: Int32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int64) => { - let ca: Int64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt8) => { - let ca: UInt8Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt16) => { - let ca: UInt16Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt32) => { - let ca: UInt32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt64) => { - let ca: UInt64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Float32) => { - let ca: Float32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Float64) => { - let ca: Float64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Boolean) => { - let ca: BooleanChunked = dispatch_apply!( - series, - apply_lambda_with_bool_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::String) => { - let ca = dispatch_apply!( - series, - apply_lambda_with_string_out_type, - py, - lambda, - 0, - None - )?; - - ca.into_series() - }, - #[cfg(feature = "object")] - Some(DataType::Object(_, _)) => { - let ca = dispatch_apply!( - series, - apply_lambda_with_object_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - None => return dispatch_apply!(series, apply_lambda_unknown, py, lambda), - - _ => return dispatch_apply!(series, apply_lambda_unknown, py, lambda), - }; - - Ok(out.into()) - }) - } - fn zip_with(&self, mask: &PySeries, other: &PySeries) -> PyResult { let mask = mask.series.bool().map_err(PyPolarsErr::from)?; let s = self diff --git a/crates/polars-python/src/series/map.rs b/crates/polars-python/src/series/map.rs new file mode 100644 index 000000000000..967110019c44 --- /dev/null +++ b/crates/polars-python/src/series/map.rs @@ -0,0 +1,245 @@ +use pyo3::prelude::*; +use pyo3::Python; + +use super::PySeries; +use crate::error::PyPolarsErr; +use crate::map::series::{call_lambda_and_extract, ApplyLambda}; +use crate::prelude::*; +use crate::{apply_method_all_arrow_series2, raise_err}; + +#[pymethods] +impl PySeries { + #[pyo3(signature = (function, return_dtype, skip_nulls))] + fn map_elements( + &self, + function: &Bound, + return_dtype: Option>, + skip_nulls: bool, + ) -> PyResult { + let series = &self.series; + + if return_dtype.is_none() { + polars_warn!( + MapWithoutReturnDtypeWarning, + "Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \ + Specify `return_dtype` to silence this warning.") + } + + if skip_nulls && (series.null_count() == series.len()) { + if let Some(return_dtype) = return_dtype { + return Ok( + Series::full_null(series.name().clone(), series.len(), &return_dtype.0).into(), + ); + } + let msg = "The output type of the 'map_elements' function cannot be determined.\n\ + The function was never called because 'skip_nulls=True' and all values are null.\n\ + Consider setting 'skip_nulls=False' or setting the 'return_dtype'."; + raise_err!(msg, ComputeError) + } + + let return_dtype = return_dtype.map(|dt| dt.0); + + macro_rules! dispatch_apply { + ($self:expr, $method:ident, $($args:expr),*) => { + match $self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_, _) => { + let ca = $self.0.unpack::>().unwrap(); + ca.$method($($args),*) + }, + _ => { + apply_method_all_arrow_series2!( + $self, + $method, + $($args),* + ) + } + + } + } + + } + + Python::with_gil(|py| { + if matches!( + self.series.dtype(), + DataType::Datetime(_, _) + | DataType::Date + | DataType::Duration(_) + | DataType::Categorical(_, _) + | DataType::Enum(_, _) + | DataType::Binary + | DataType::Array(_, _) + | DataType::Time + ) || !skip_nulls + { + let mut avs = Vec::with_capacity(self.series.len()); + let s = self.series.rechunk(); + let iter = s.iter().map(|av| match (skip_nulls, av) { + (true, AnyValue::Null) => AnyValue::Null, + (_, av) => { + let input = Wrap(av); + call_lambda_and_extract::<_, Wrap>(py, function, input) + .unwrap() + .0 + }, + }); + avs.extend(iter); + return Ok(Series::new(self.series.name().clone(), &avs).into()); + } + + let out = match return_dtype { + Some(DataType::Int8) => { + let ca: Int8Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int16) => { + let ca: Int16Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int32) => { + let ca: Int32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int64) => { + let ca: Int64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt8) => { + let ca: UInt8Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt16) => { + let ca: UInt16Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt32) => { + let ca: UInt32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt64) => { + let ca: UInt64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Float32) => { + let ca: Float32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Float64) => { + let ca: Float64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Boolean) => { + let ca: BooleanChunked = dispatch_apply!( + series, + apply_lambda_with_bool_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::String) => { + let ca = dispatch_apply!( + series, + apply_lambda_with_string_out_type, + py, + function, + 0, + None + )?; + + ca.into_series() + }, + #[cfg(feature = "object")] + Some(DataType::Object(_, _)) => { + let ca = dispatch_apply!( + series, + apply_lambda_with_object_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + None => return dispatch_apply!(series, apply_lambda_unknown, py, function), + + _ => return dispatch_apply!(series, apply_lambda_unknown, py, function), + }; + + Ok(out.into()) + }) + } +} diff --git a/crates/polars-python/src/series/mod.rs b/crates/polars-python/src/series/mod.rs index a5d8a8fc77ea..1b4542b06c5a 100644 --- a/crates/polars-python/src/series/mod.rs +++ b/crates/polars-python/src/series/mod.rs @@ -17,6 +17,8 @@ mod general; #[cfg(feature = "pymethods")] mod import; #[cfg(feature = "pymethods")] +mod map; +#[cfg(feature = "pymethods")] mod numpy_ufunc; #[cfg(feature = "pymethods")] mod scatter; diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 27492285ce59..e2a4cb936e47 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -5336,7 +5336,9 @@ def map_elements( warn_on_inefficient_map(function, columns=[self.name], map_target="series") return self._from_pyseries( - self._s.apply_lambda(function, pl_return_dtype, skip_nulls) + self._s.map_elements( + function, return_dtype=pl_return_dtype, skip_nulls=skip_nulls + ) ) def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> Series: diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index 7edc155e223f..b2affd4fdbd0 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -10,6 +10,10 @@ from polars.exceptions import PolarsInefficientMapWarning from polars.testing import assert_frame_equal, assert_series_equal +pytestmark = pytest.mark.filterwarnings( + "ignore::polars.exceptions.PolarsInefficientMapWarning" +) + def test_map_elements_infer_list() -> None: df = pl.DataFrame( @@ -180,17 +184,13 @@ def test_map_elements_skip_nulls() -> None: some_map = {None: "a", 1: "b"} s = pl.Series([None, 1]) - with pytest.warns( - PolarsInefficientMapWarning, - match=r"(?s)Replace this expression.*s\.map_elements\(lambda x:", - ): - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String - ).to_list() == [None, "b"] + assert s.map_elements( + lambda x: some_map[x], return_dtype=pl.String, skip_nulls=True + ).to_list() == [None, "b"] - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False - ).to_list() == ["a", "b"] + assert s.map_elements( + lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False + ).to_list() == ["a", "b"] def test_map_elements_object_dtypes() -> None: @@ -364,3 +364,10 @@ def test_unknown_map_elements() -> None: "Flour": [10.0, 100.0, 100.0, 20.0], } assert q.collect_schema().dtypes() == [pl.Int64, pl.Unknown] + + +def test_map_elements_list_dtype_18472() -> None: + s = pl.Series([[None], ["abc ", None]]) + result = s.map_elements(lambda s: [i.strip() if i else None for i in s]) + expected = pl.Series([[None], ["abc", None]]) + assert_series_equal(result, expected)