diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index 87dc1d8e01f..4da7cd6bbd7 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -27,7 +27,9 @@ def _normalize_scalars(col: ColumnBase, other: ScalarLike) -> ScalarLike: f"{type(other).__name__} to {col.dtype.name}" ) - return cudf.Scalar(other, dtype=col.dtype if other is None else None) + return cudf.Scalar( + other, dtype=col.dtype if other in {None, cudf.NA} else None + ) def _check_and_cast_columns_with_other( @@ -234,9 +236,15 @@ def where( if isinstance(frame, DataFrame): if hasattr(cond, "__cuda_array_interface__"): - cond = DataFrame( - cond, columns=frame._column_names, index=frame.index - ) + if isinstance(cond, Series): + cond = DataFrame( + {name: cond for name in frame._column_names}, + index=frame.index, + ) + else: + cond = DataFrame( + cond, columns=frame._column_names, index=frame.index + ) elif ( hasattr(cond, "__array_interface__") and cond.__array_interface__["shape"] != frame.shape diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 8744238a062..14176fd932d 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8731,3 +8731,23 @@ def test_frame_series_where(): expected = gdf.where(gdf.notna(), gdf.mean()) actual = pdf.where(pdf.notna(), pdf.mean(), axis=1) assert_eq(expected, actual) + + +@pytest.mark.parametrize( + "data", [{"a": [1, 2, 3], "b": [1, 1, 0]}], +) +def test_frame_series_where_other(data): + gdf = cudf.DataFrame(data) + pdf = gdf.to_pandas() + + expected = gdf.where(gdf["b"] == 1, cudf.NA) + actual = pdf.where(pdf["b"] == 1, pd.NA) + assert_eq( + actual.fillna(-1).values, + expected.fillna(-1).values, + check_dtype=False, + ) + + expected = gdf.where(gdf["b"] == 1, 0) + actual = pdf.where(pdf["b"] == 1, 0) + assert_eq(expected, actual) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index 46bd1b449c4..829a1545365 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -581,6 +581,8 @@ def _can_cast(from_dtype, to_dtype): `np.can_cast` but with some special handling around cudf specific dtypes. """ + if from_dtype in {None, cudf.NA}: + return True if isinstance(from_dtype, type): from_dtype = np.dtype(from_dtype) if isinstance(to_dtype, type):