Skip to content

Commit

Permalink
fix(python): Fix dtype parameter in pandas_to_pyseries function (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
luke396 authored Apr 29, 2024
1 parent 3bf32f0 commit 2805eca
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,17 @@ def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> Series:
def pandas_to_pyseries(
name: str,
values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex,
dtype: PolarsDataType | None = None,
*,
nan_to_null: bool = True,
) -> PySeries:
"""Construct a PySeries from a pandas Series or DatetimeIndex."""
if not name and values.name is not None:
name = str(values.name)
if is_simple_numpy_backed_pandas_series(values):
return pl.Series(name, values.to_numpy(), nan_to_null=nan_to_null)._s
return pl.Series(
name, values.to_numpy(), dtype=dtype, nan_to_null=nan_to_null
)._s
if not _PYARROW_AVAILABLE:
msg = (
"pyarrow is required for converting a pandas series to Polars, "
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(
elif _check_for_pandas(values) and isinstance(
values, (pd.Series, pd.Index, pd.DatetimeIndex)
):
self._s = pandas_to_pyseries(name, values)
self._s = pandas_to_pyseries(name, values, dtype=dtype)

elif _is_generator(values):
self._s = iterable_to_pyseries(name, values, dtype=dtype, strict=strict)
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,3 +2337,8 @@ def test_search_sorted(

multiple_s = s.search_sorted(multiple)
assert_series_equal(multiple_s, pl.Series(multiple_expected, dtype=pl.UInt32))


def test_series_from_pandas_with_dtype() -> None:
s = pl.Series("foo", pd.Series([1, 2, 3]), pl.Float32)
assert_series_equal(s, pl.Series("foo", [1, 2, 3], dtype=pl.Float32))

0 comments on commit 2805eca

Please sign in to comment.