Skip to content

Commit

Permalink
fix: Fix type inference error in map_elements for List types (#18542)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Sep 4, 2024
1 parent 6d4b79d commit dab8254
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 256 deletions.
9 changes: 4 additions & 5 deletions crates/polars-python/src/map/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@ use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
use super::*;
use crate::PyDataFrame;

/// Create iterators for all the Series in the DataFrame.
fn get_iters(df: &DataFrame) -> Vec<SeriesIter> {
df.get_columns().iter().map(|s| s.iter()).collect()
}

fn get_iters_skip(df: &DataFrame, skip: usize) -> Vec<std::iter::Skip<SeriesIter>> {
df.get_columns()
.iter()
.map(|s| s.iter().skip(skip))
.collect()
/// Create iterators for all the Series in the DataFrame, skipping the first `n` rows.
fn get_iters_skip(df: &DataFrame, n: usize) -> Vec<std::iter::Skip<SeriesIter>> {
df.get_columns().iter().map(|s| s.iter().skip(n)).collect()
}

// the return type is Union[PySeries, PyDataFrame] and a boolean indicating if it is a dataframe or not
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-python/src/map/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
let py_pyseries = series.getattr(py, "_s").unwrap();
let series = py_pyseries.extract::<PySeries>(py).unwrap().series;

// Empty dtype is incorrect, use AnyValues.
if series.is_empty() {
let dt = series.dtype();

// Null dtype may be incorrect, fall back to AnyValues logic.
if dt.is_nested_null() {
let av = out.extract::<Wrap<AnyValue>>()?;
return applyer
.apply_extract_any_values(py, lambda, null_count, av.0)
.map(|s| s.into());
}

let dt = series.dtype();

// make a new python function that is:
// def new_lambda(lambda: Callable):
// pl.Series(lambda(value))
Expand Down
236 changes: 0 additions & 236 deletions crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ use pyo3::Python;
use super::PySeries;
use crate::dataframe::PyDataFrame;
use crate::error::PyPolarsErr;
use crate::map::series::{call_lambda_and_extract, ApplyLambda};
use crate::prelude::*;
use crate::py_modules::POLARS;
use crate::{apply_method_all_arrow_series2, raise_err};

#[pymethods]
impl PySeries {
Expand Down Expand Up @@ -315,240 +313,6 @@ impl PySeries {
self.series.clone().into()
}

#[pyo3(signature = (lambda, output_type, skip_nulls))]
fn apply_lambda(
&self,
lambda: &Bound<PyAny>,
output_type: Option<Wrap<DataType>>,
skip_nulls: bool,
) -> PyResult<PySeries> {
let series = &self.series;

if output_type.is_none() {
polars_warn!(
MapWithoutReturnDtypeWarning,
"Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \
Specify `return_dtype` to silence this warning.")
}

if skip_nulls && (series.null_count() == series.len()) {
if let Some(output_type) = output_type {
return Ok(
Series::full_null(series.name().clone(), series.len(), &output_type.0).into(),
);
}
let msg = "The output type of the 'map_elements' function cannot be determined.\n\
The function was never called because 'skip_nulls=True' and all values are null.\n\
Consider setting 'skip_nulls=False' or setting the 'return_dtype'.";
raise_err!(msg, ComputeError)
}

let output_type = output_type.map(|dt| dt.0);

macro_rules! dispatch_apply {
($self:expr, $method:ident, $($args:expr),*) => {
match $self.dtype() {
#[cfg(feature = "object")]
DataType::Object(_, _) => {
let ca = $self.0.unpack::<ObjectType<ObjectValue>>().unwrap();
ca.$method($($args),*)
},
_ => {
apply_method_all_arrow_series2!(
$self,
$method,
$($args),*
)
}

}
}

}

Python::with_gil(|py| {
if matches!(
self.series.dtype(),
DataType::Datetime(_, _)
| DataType::Date
| DataType::Duration(_)
| DataType::Categorical(_, _)
| DataType::Enum(_, _)
| DataType::Binary
| DataType::Array(_, _)
| DataType::Time
) || !skip_nulls
{
let mut avs = Vec::with_capacity(self.series.len());
let s = self.series.rechunk();
let iter = s.iter().map(|av| match (skip_nulls, av) {
(true, AnyValue::Null) => AnyValue::Null,
(_, av) => {
let input = Wrap(av);
call_lambda_and_extract::<_, Wrap<AnyValue>>(py, lambda, input)
.unwrap()
.0
},
});
avs.extend(iter);
return Ok(Series::new(self.series.name().clone(), &avs).into());
}

let out = match output_type {
Some(DataType::Int8) => {
let ca: Int8Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Int16) => {
let ca: Int16Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Int32) => {
let ca: Int32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Int64) => {
let ca: Int64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::UInt8) => {
let ca: UInt8Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::UInt16) => {
let ca: UInt16Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::UInt32) => {
let ca: UInt32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::UInt64) => {
let ca: UInt64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Float32) => {
let ca: Float32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Float64) => {
let ca: Float64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::Boolean) => {
let ca: BooleanChunked = dispatch_apply!(
series,
apply_lambda_with_bool_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
Some(DataType::String) => {
let ca = dispatch_apply!(
series,
apply_lambda_with_string_out_type,
py,
lambda,
0,
None
)?;

ca.into_series()
},
#[cfg(feature = "object")]
Some(DataType::Object(_, _)) => {
let ca = dispatch_apply!(
series,
apply_lambda_with_object_out_type,
py,
lambda,
0,
None
)?;
ca.into_series()
},
None => return dispatch_apply!(series, apply_lambda_unknown, py, lambda),

_ => return dispatch_apply!(series, apply_lambda_unknown, py, lambda),
};

Ok(out.into())
})
}

fn zip_with(&self, mask: &PySeries, other: &PySeries) -> PyResult<Self> {
let mask = mask.series.bool().map_err(PyPolarsErr::from)?;
let s = self
Expand Down
Loading

0 comments on commit dab8254

Please sign in to comment.