diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py index 4ef363abac48..7d1ef963dff1 100644 --- a/py-polars/polars/internals/lazy_functions.py +++ b/py-polars/polars/internals/lazy_functions.py @@ -11,6 +11,7 @@ Date, Datetime, Duration, + Int32, Int64, PolarsDataType, SchemaDict, @@ -2791,6 +2792,12 @@ def repeat( if name is None: name = "" dtype = py_type_to_dtype(type(value)) + if ( + dtype == Int64 + and isinstance(value, int) + and -(2**31) <= value <= 2**31 - 1 + ): + dtype = Int32 s = pli.Series._repeat(name, value, n, dtype) # type: ignore[arg-type] return s else: diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 6a1b49f10d3e..3f780110c871 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -154,7 +154,7 @@ fn repeat(value: &PyAny, n_times: PyExpr) -> PyResult { } else if let Ok(int) = value.downcast::() { let val = int.extract::().unwrap(); - if val >= 0 && val <= i32::MAX as i64 || val < 0 && val >= i32::MIN as i64 { + if val >= i32::MIN as i64 && val <= i32::MAX as i64 { Ok(polars_rs::lazy::dsl::repeat(val as i32, n_times.inner).into()) } else { Ok(polars_rs::lazy::dsl::repeat(val, n_times.inner).into()) diff --git a/py-polars/src/series.rs b/py-polars/src/series.rs index 51998dee74b9..720ad09e6595 100644 --- a/py-polars/src/series.rs +++ b/py-polars/src/series.rs @@ -285,6 +285,12 @@ impl PySeries { ca.rename(name); ca.into_inner().into_series().into() } + DataType::Int32 => { + let val = val.extract::().unwrap(); + let mut ca: NoNull = (0..n).map(|_| val).collect_trusted(); + ca.rename(name); + ca.into_inner().into_series().into() + } DataType::Float64 => { let val = val.extract::().unwrap(); let mut ca: NoNull = (0..n).map(|_| val).collect_trusted(); diff --git a/py-polars/tests/unit/test_functions.py b/py-polars/tests/unit/test_functions.py index 1632176a03f7..e745923aa229 100644 --- a/py-polars/tests/unit/test_functions.py +++ b/py-polars/tests/unit/test_functions.py @@ -317,3 +317,41 @@ def test_fill_null_unknown_output_type() -> None: 148.4131591025766, ] } + + +def test_repeat() -> None: + s = pl.select(pl.repeat(2**31 - 1, 3)).to_series() + assert s.dtype == pl.Int32 + assert s.len() == 3 + assert s.to_list() == [2**31 - 1] * 3 + s = pl.select(pl.repeat(-(2**31), 4)).to_series() + assert s.dtype == pl.Int32 + assert s.len() == 4 + assert s.to_list() == [-(2**31)] * 4 + s = pl.select(pl.repeat(2**31, 5)).to_series() + assert s.dtype == pl.Int64 + assert s.len() == 5 + assert s.to_list() == [2**31] * 5 + s = pl.select(pl.repeat(-(2**31) - 1, 3)).to_series() + assert s.dtype == pl.Int64 + assert s.len() == 3 + assert s.to_list() == [-(2**31) - 1] * 3 + s = pl.select(pl.repeat("foo", 2)).to_series() + assert s.dtype == pl.Utf8 + assert s.len() == 2 + assert s.to_list() == ["foo"] * 2 + s = pl.select(pl.repeat(1.0, 5)).to_series() + assert s.dtype == pl.Float64 + assert s.len() == 5 + assert s.to_list() == [1.0] * 5 + s = pl.select(pl.repeat(True, 4)).to_series() + assert s.dtype == pl.Boolean + assert s.len() == 4 + assert s.to_list() == [True] * 4 + s = pl.select(pl.repeat(None, 7)).to_series() + assert s.dtype == pl.Null + assert s.len() == 7 + assert s.to_list() == [None] * 7 + s = pl.select(pl.repeat(0, 0)).to_series() + assert s.dtype == pl.Int32 + assert s.len() == 0 diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 80501954af11..e8b12e8aacec 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -960,22 +960,41 @@ def test_object() -> None: def test_repeat() -> None: - s = pl.repeat(1, 10, eager=True) + s = pl.repeat(2**31 - 1, 3, eager=True) + assert s.dtype == pl.Int32 + assert s.len() == 3 + assert s.to_list() == [2**31 - 1] * 3 + s = pl.repeat(-(2**31), 4, eager=True) + assert s.dtype == pl.Int32 + assert s.len() == 4 + assert s.to_list() == [-(2**31)] * 4 + s = pl.repeat(2**31, 5, eager=True) + assert s.dtype == pl.Int64 + assert s.len() == 5 + assert s.to_list() == [2**31] * 5 + s = pl.repeat(-(2**31) - 1, 3, eager=True) assert s.dtype == pl.Int64 - assert s.len() == 10 - s = pl.repeat("foo", 10, eager=True) + assert s.len() == 3 + assert s.to_list() == [-(2**31) - 1] * 3 + s = pl.repeat("foo", 2, eager=True) assert s.dtype == pl.Utf8 - assert s.len() == 10 + assert s.len() == 2 + assert s.to_list() == ["foo"] * 2 s = pl.repeat(1.0, 5, eager=True) assert s.dtype == pl.Float64 assert s.len() == 5 - assert s.to_list() == [1.0, 1.0, 1.0, 1.0, 1.0] - s = pl.repeat(True, 5, eager=True) + assert s.to_list() == [1.0] * 5 + s = pl.repeat(True, 4, eager=True) assert s.dtype == pl.Boolean - assert s.len() == 5 + assert s.len() == 4 + assert s.to_list() == [True] * 4 s = pl.repeat(None, 7, eager=True) assert s.dtype == pl.Null assert s.len() == 7 + assert s.to_list() == [None] * 7 + s = pl.repeat(0, 0, eager=True) + assert s.dtype == pl.Int32 + assert s.len() == 0 def test_shape() -> None: