Skip to content

Commit

Permalink
refactor(python): Very minor refactor of DataFrame.to_numpy code (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 20, 2024
1 parent f5c32f2 commit ec904e6
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions py-polars/src/to_numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use numpy::{
npyffi, Element, IntoPyArray, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, PY_ARRAY_API,
};
use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;
use polars_core::utils::dtypes_to_supertype;
use polars_core::with_match_physical_numeric_polars_type;
use pyo3::intern;
use pyo3::prelude::*;
Expand Down Expand Up @@ -253,6 +253,11 @@ pub(crate) fn reshape_numpy_array(
#[pymethods]
#[allow(clippy::wrong_self_convention)]
impl PyDataFrame {
/// Create a view of the data as a NumPy ndarray.
///
/// WARNING: The resulting view will show the underlying value for nulls,
/// which may be any value. The caller is responsible for handling nulls
/// appropriately.
pub fn to_numpy_view(&self, py: Python) -> Option<PyObject> {
if self.df.is_empty() {
return None;
Expand Down Expand Up @@ -320,31 +325,14 @@ impl PyDataFrame {
})
}

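/// Convert this DataFrame to a NumPy ndarray.
///
/// The element type of the resulting array is the common supertype of all
/// column dtypes. Returns `None` if no common supertype can be determined,
/// if that supertype is not numeric, or if the conversion to an ndarray fails.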
pub fn to_numpy(&self, py: Python, order: Wrap<IndexOrder>) -> Option<PyObject> {
let mut st = None;
for s in self.df.iter() {
let dt_i = s.dtype();
match st {
None => st = Some(dt_i.clone()),
Some(ref mut st) => {
*st = try_get_supertype(st, dt_i).ok()?;
},
}
}
let st = st?;
let st = dtypes_to_supertype(self.df.iter().map(|s| s.dtype())).ok()?;

#[rustfmt::skip]
let pyarray = match st {
DataType::UInt8 => self.df.to_ndarray::<UInt8Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Int8 => self.df.to_ndarray::<Int8Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::UInt16 => self.df.to_ndarray::<UInt16Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Int16 => self.df.to_ndarray::<Int16Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::UInt32 => self.df.to_ndarray::<UInt32Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::UInt64 => self.df.to_ndarray::<UInt64Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Int32 => self.df.to_ndarray::<Int32Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Int64 => self.df.to_ndarray::<Int64Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Float32 => self.df.to_ndarray::<Float32Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
DataType::Float64 => self.df.to_ndarray::<Float64Type>(order.0).ok()?.into_pyarray_bound(py).into_py(py),
dt if dt.is_numeric() => with_match_physical_numeric_polars_type!(dt, |$T| {
self.df.to_ndarray::<$T>(order.0).ok()?.into_pyarray_bound(py).into_py(py)
}),
_ => return None,
};
Some(pyarray)
Expand Down

0 comments on commit ec904e6

Please sign in to comment.