Skip to content

Commit

Permalink
New array conversion methods (#9236)
Browse files Browse the repository at this point in the history
This PR adds the `to_numpy` and `to_cupy` methods. `to_numpy` is the preferred method for generating numpy arrays as of pandas 1.0 when older methods like `to_matrix` were removed. This PR deprecates those old methods as well as their gpu counterparts (such as `as_gpu_matrix`) and instead adds `to_cupy` as the preferred method of getting a `__cuda_array_interface__` compliant array view. The new methods are also preferred to the `.values` and `.values_host` accessors, which are not yet deprecated in pandas but are likely to be deprecated at some point due to the ambiguity of their copy semantics.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Michael Wang (https://github.com/isVoid)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #9236
  • Loading branch information
vyasr authored Oct 1, 2021
1 parent 34b54ca commit e597075
Show file tree
Hide file tree
Showing 50 changed files with 541 additions and 601 deletions.
16 changes: 1 addition & 15 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,22 +798,8 @@ def astype(self, dtype, copy=False):
self.copy(deep=copy)._values.astype(dtype), name=self.name
)

# TODO: This method is deprecated and can be removed.
def to_array(self, fillna=None):
"""Get a dense numpy array for the data.
Parameters
----------
fillna : str or None
Defaults to None, which will skip null values.
If it equals "pandas", null values are filled with NaNs.
Non integral dtype is promoted to np.float64.
Notes
-----
if ``fillna`` is ``None``, null values are skipped. Therefore, the
output size could be smaller.
"""
return self._values.to_array(fillna=fillna)

def to_series(self, index=None, name=None):
Expand Down
10 changes: 5 additions & 5 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def remove_categories(
# ensure all the removals are in the current categories
# list. If not, raise an error to match Pandas behavior
if not removals_mask.all():
vals = removals[~removals_mask].to_array()
vals = removals[~removals_mask].to_numpy()
raise ValueError(f"removals must all be in old categories: {vals}")

new_categories = cats[~cats.isin(removals)]._column
Expand Down Expand Up @@ -1012,11 +1012,11 @@ def _encode(self, value) -> ScalarLike:
return self.categories.find_first_value(value)

def _decode(self, value: int) -> ScalarLike:
if value == self.default_na_value():
if value == self._default_na_value():
return None
return self.categories.element_indexing(value)

def default_na_value(self) -> ScalarLike:
def _default_na_value(self) -> ScalarLike:
return -1

def find_and_replace(
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def fillna(
fill_is_scalar = np.isscalar(fill_value)

if fill_is_scalar:
if fill_value == self.default_na_value():
if fill_value == self._default_na_value():
fill_value = self.codes.dtype.type(fill_value)
else:
try:
Expand Down Expand Up @@ -1578,7 +1578,7 @@ def _create_empty_categorical_column(
categories=column.as_column(dtype.categories),
codes=column.as_column(
cudf.utils.utils.scalar_broadcast_to(
categorical_column.default_na_value(),
categorical_column._default_na_value(),
categorical_column.size,
categorical_column.codes.dtype,
)
Expand Down
16 changes: 13 additions & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ def values_host(self) -> "np.ndarray":
"""
Return a numpy representation of the Column.
"""
if len(self) == 0:
return np.array([], dtype=self.dtype)

if self.has_nulls:
raise ValueError("Column must have no nulls.")

return self.data_array_view.copy_to_host()

@property
Expand All @@ -138,7 +144,7 @@ def values(self) -> "cupy.ndarray":
Return a CuPy representation of the Column.
"""
if len(self) == 0:
return cupy.asarray([], dtype=self.dtype)
return cupy.array([], dtype=self.dtype)

if self.has_nulls:
raise ValueError("Column must have no nulls.")
Expand Down Expand Up @@ -319,9 +325,11 @@ def _get_mask_as_column(self) -> ColumnBase:
def _memory_usage(self, **kwargs) -> int:
return self.__sizeof__()

def default_na_value(self) -> Any:
def _default_na_value(self) -> Any:
raise NotImplementedError()

# TODO: This method is decpreated and can be removed when the associated
# Frame methods are removed.
def to_gpu_array(self, fillna=None) -> "cuda.devicearray.DeviceNDArray":
"""Get a dense numba device array for the data.
Expand All @@ -337,10 +345,12 @@ def to_gpu_array(self, fillna=None) -> "cuda.devicearray.DeviceNDArray":
output size could be smaller.
"""
if fillna:
return self.fillna(self.default_na_value()).data_array_view
return self.fillna(self._default_na_value()).data_array_view
else:
return self.dropna(drop_nan=False).data_array_view

# TODO: This method is decpreated and can be removed when the associated
# Frame methods are removed.
def to_array(self, fillna=None) -> np.ndarray:
"""Get a dense numpy array for the data.
Expand Down
9 changes: 3 additions & 6 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def as_string_column(
column.column_empty(0, dtype="object", masked=False),
)

def default_na_value(self) -> DatetimeLikeScalar:
def _default_na_value(self) -> DatetimeLikeScalar:
"""Returns the default NA value for this column
"""
return np.datetime64("nat", self.time_unit)
Expand Down Expand Up @@ -491,14 +491,11 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
def _make_copy_with_na_as_null(self):
"""Return a copy with NaN values replaced with nulls."""
null = column_empty_like(self, masked=True, newsize=1)
na_value = np.datetime64("nat", self.time_unit)
out_col = cudf._lib.replace.replace(
self,
as_column(
Buffer(
np.array([self.default_na_value()], dtype=self.dtype).view(
"|u1"
)
),
Buffer(np.array([na_value], dtype=self.dtype).view("|u1")),
dtype=self.dtype,
),
null,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _process_values_for_isin(

return lhs, rhs

def default_na_value(self) -> ScalarLike:
def _default_na_value(self) -> ScalarLike:
"""Returns the default NA value for this column
"""
dkind = self.dtype.kind
Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5210,10 +5210,10 @@ def values(self) -> cupy.ndarray:
"""
Return a CuPy representation of the StringColumn.
"""
raise NotImplementedError(
"String Arrays is not yet implemented in cudf"
)
raise TypeError("String Arrays is not yet implemented in cudf")

# TODO: This method is deprecated and should be removed when the associated
# Frame methods are removed.
def to_array(self, fillna: bool = None) -> np.ndarray:
"""Get a dense numpy array for the data.
Expand Down Expand Up @@ -5409,7 +5409,7 @@ def normalize_binop_value(self, other) -> "column.ColumnBase":
else:
raise TypeError(f"cannot broadcast {type(other)}")

def default_na_value(self) -> ScalarLike:
def _default_na_value(self) -> ScalarLike:
return None

def binary_operator(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def as_numerical(self) -> "cudf.core.column.NumericalColumn":
),
)

def default_na_value(self) -> ScalarLike:
def _default_na_value(self) -> ScalarLike:
"""Returns the default NA value for this column
"""
return np.timedelta64("nat", self.time_unit)
Expand Down
75 changes: 14 additions & 61 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _init_from_series_list(self, data, columns, index):
# Setting `final_columns` to self._index so
# that the resulting `transpose` will be have
# columns set to `final_columns`
self._index = final_columns
self._index = as_index(final_columns)

transpose = self.T
else:
Expand Down Expand Up @@ -968,36 +968,6 @@ def __array_function__(self, func, types, args, kwargs):
else:
return NotImplemented

@property
def values(self):
"""
Return a CuPy representation of the DataFrame.
Only the values in the DataFrame will be returned, the axes labels will
be removed.
Returns
-------
out: cupy.ndarray
The values of the DataFrame.
"""
return cupy.asarray(self.as_gpu_matrix())

def __array__(self, dtype=None):
raise TypeError(
"Implicit conversion to a host NumPy array via __array__ is not "
"allowed, To explicitly construct a GPU matrix, consider using "
".as_gpu_matrix()\nTo explicitly construct a host "
"matrix, consider using .as_matrix()"
)

def __arrow_array__(self, type=None):
raise TypeError(
"Implicit conversion to a host PyArrow Table via __arrow_array__ "
"is not allowed, To explicitly construct a PyArrow Table, "
"consider using .to_arrow()"
)

def _get_numeric_data(self):
""" Return a dataframe with only numeric data types """
columns = [
Expand Down Expand Up @@ -2740,7 +2710,7 @@ def columns(self, columns):
if isinstance(
columns, (Series, cudf.Index, cudf.core.column.ColumnBase)
):
columns = pd.Index(columns.to_array(), tupleize_cols=is_multiindex)
columns = pd.Index(columns.to_numpy(), tupleize_cols=is_multiindex)
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns, tupleize_cols=is_multiindex)

Expand Down Expand Up @@ -3724,21 +3694,11 @@ def rename(
return out.copy(deep=copy)

def as_gpu_matrix(self, columns=None, order="F"):
"""Convert to a matrix in device memory.
Parameters
----------
columns : sequence of str
List of a column names to be extracted. The order is preserved.
If None is specified, all columns are used.
order : 'F' or 'C'
Optional argument to determine whether to return a column major
(Fortran) matrix or a row major (C) matrix.
Returns
-------
A (nrow x ncol) numba device ndarray
"""
warnings.warn(
"The as_gpu_matrix method will be removed in a future cuDF "
"release. Consider using `to_cupy` instead.",
DeprecationWarning,
)
if columns is None:
columns = self._data.names

Expand Down Expand Up @@ -3782,18 +3742,11 @@ def as_gpu_matrix(self, columns=None, order="F"):
return cuda.as_cuda_array(matrix).view(dtype)

def as_matrix(self, columns=None):
"""Convert to a matrix in host memory.
Parameters
----------
columns : sequence of str
List of a column names to be extracted. The order is preserved.
If None is specified, all columns are used.
Returns
-------
A (nrow x ncol) numpy ndarray in "F" order.
"""
warnings.warn(
"The as_matrix method will be removed in a future cuDF "
"release. Consider using `to_numpy` instead.",
DeprecationWarning,
)
return self.as_gpu_matrix(columns=columns).copy_to_host()

def one_hot_encoding(
Expand Down Expand Up @@ -5754,9 +5707,9 @@ def to_records(self, index=True):
dtype = np.dtype(members)
ret = np.recarray(len(self), dtype=dtype)
if index:
ret["index"] = self.index.to_array()
ret["index"] = self.index.to_numpy()
for col in self._data.names:
ret[col] = self[col].to_array()
ret[col] = self[col].to_numpy()
return ret

@classmethod
Expand Down
Loading

0 comments on commit e597075

Please sign in to comment.