diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index da79a0cf57df..d1f39c3b135a 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4525,7 +4525,7 @@ def rename( """ return self.lazy().rename(mapping, strict=strict).collect(_eager=True) - def insert_column(self, index: int, column: Series) -> DataFrame: + def insert_column(self, index: int, column: IntoExprColumn) -> DataFrame: """ Insert a Series at a certain column index. @@ -4536,7 +4536,7 @@ def insert_column(self, index: int, column: Series) -> DataFrame: index Index at which to insert the new `Series` column. column - `Series` to insert. + `Series` or expression to insert. Examples -------- @@ -4575,9 +4575,27 @@ def insert_column(self, index: int, column: Series) -> DataFrame: │ 4 ┆ 13.0 ┆ true ┆ 0.0 │ └─────┴──────┴───────┴──────┘ """ - if index < 0: + if (original_index := index) < 0: index = len(self.columns) + index - self._df.insert_column(index, column._s) + if index < 0: + msg = f"column index {original_index} is out of range (frame has {len(self.columns)} columns)" + raise IndexError(msg) + elif index > len(self.columns): + msg = f"column index {original_index} is out of range (frame has {len(self.columns)} columns)" + raise IndexError(msg) + + if isinstance(column, pl.Series): + self._df.insert_column(index, column._s) + else: + if isinstance(column, str): + column = F.col(column) + if isinstance(column, pl.Expr): + cols = self.columns + cols.insert(index, column) # type: ignore[arg-type] + self._df = self.select(cols)._df + else: + msg = f"column must be a Series or Expr, got {column!r} (type={type(column)})" + raise TypeError(msg) return self def filter( diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 2389bef36c88..015b64ce6303 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -458,6 +458,7 @@ def test_assignment() -> None: def test_insert_column() -> None: + # insert series df = ( pl.DataFrame({"z": [3, 4, 5]}) .insert_column(0, pl.Series("x", [1, 2, 3])) @@ -466,6 +467,39 @@ def test_insert_column() -> None: expected_df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) assert_frame_equal(expected_df, df) + # insert expressions + df = pl.DataFrame( + { + "id": ["xx", "yy", "zz"], + "v1": [5, 4, 6], + "v2": [7, 3, 3], + } + ) + df.insert_column(3, (pl.col("v1") * pl.col("v2")).alias("v3")) + df.insert_column(1, (pl.col("v2") - pl.col("v1")).alias("v0")) + + expected = pl.DataFrame( + { + "id": ["xx", "yy", "zz"], + "v0": [2, -1, -3], + "v1": [5, 4, 6], + "v2": [7, 3, 3], + "v3": [35, 12, 18], + } + ) + assert_frame_equal(df, expected) + + # check that we raise suitable index errors + for idx, column in ( + (10, pl.col("v1").sqrt().alias("v1_sqrt")), + (-10, pl.Series("foo", [1, 2, 3])), + ): + with pytest.raises( + IndexError, + match=rf"column index {idx} is out of range \(frame has 5 columns\)", + ): + df.insert_column(idx, column) + def test_replace_column() -> None: df = (