Skip to content

Commit

Permalink
Enable round in cudf for DataFrame and Series (#7022)
Browse files Browse the repository at this point in the history
This enables round for DataFrames and Series using the libcudf round implementation and removes the old numba round implementation.

Closes #1270

Authors:
  - @ChrisJar

Approvers:
  - Ashwin Srinath (@shwina)
  - Michael Wang (@isVoid)
  - Ram (Ramakrishna Prabhu) (@rgsl888prabhu)
  - GALI PREM SAGAR (@galipremsagar)

URL: #7022
  • Loading branch information
ChrisJar authored Jan 20, 2021
1 parent 95059b8 commit a51caa5
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 105 deletions.
3 changes: 2 additions & 1 deletion python/cudf/cudf/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
import numpy as np

from . import (
Expand All @@ -23,6 +23,7 @@
replace,
reshape,
rolling,
round,
search,
sort,
stream_compaction,
Expand Down
19 changes: 19 additions & 0 deletions python/cudf/cudf/_lib/cpp/round.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from libc.stdint cimport int32_t
from libcpp.memory cimport unique_ptr

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view

# Cython declarations for the libcudf rounding API found in the
# "cudf/round.hpp" header (C++ namespace ``cudf``). The whole block is
# declared ``nogil`` so these symbols can be used without holding the GIL.
cdef extern from "cudf/round.hpp" namespace "cudf" nogil:

    # Mirror of the C++ ``cudf::rounding_method`` enum; every member is
    # mapped to its fully qualified C++ name.
    ctypedef enum rounding_method "cudf::rounding_method":
        HALF_UP "cudf::rounding_method::HALF_UP"
        HALF_EVEN "cudf::rounding_method::HALF_EVEN"

    # Declaration of ``cudf::round``: takes a column view, the number of
    # decimal places and a rounding method, and returns an owning pointer
    # to a column. ``except +`` lets Cython translate C++ exceptions into
    # Python exceptions.
    cdef unique_ptr[column] round (
        const column_view& input,
        int32_t decimal_places,
        rounding_method method,
    ) except +
42 changes: 42 additions & 0 deletions python/cudf/cudf/_lib/round.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

from cudf._lib.column cimport Column

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.round cimport (
rounding_method as cpp_rounding_method,
round as cpp_round
)


def round(Column input_col, int decimal_places=0):
    """
    Round column values to the given number of decimal places.

    The libcudf ``HALF_EVEN`` rounding method is always used.

    Parameters
    ----------
    input_col : Column
        Column whose values will be rounded.
    decimal_places : int, default 0
        The number of decimal places to round to.

    Returns
    -------
    Column
        A Column with values rounded to the given number of decimal places.
    """

    cdef column_view input_col_view = input_col.view()
    cdef unique_ptr[column] c_result

    # Release the GIL for the duration of the libcudf call.
    with nogil:
        c_result = move(
            cpp_round(
                input_col_view,
                decimal_places,
                cpp_rounding_method.HALF_EVEN,
            )
        )

    # Transfer ownership of the C++ result to a Python-level Column.
    return Column.from_unique_ptr(move(c_result))
16 changes: 4 additions & 12 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2020, NVIDIA CORPORATION.
# Copyright (c) 2018-2021, NVIDIA CORPORATION.

from numbers import Number

Expand Down Expand Up @@ -342,17 +342,9 @@ def corr(self, other):
return cov / lhs_std / rhs_std

def round(self, decimals=0):
    """Round the values in the Column to the given number of decimals.

    The work is delegated to ``libcudf.round.round``; no restriction on
    the sign of ``decimals`` or on the column dtype is applied here.

    Parameters
    ----------
    decimals : int, default 0
        Number of decimal places to round to. Forwarded unchanged as the
        ``decimal_places`` argument of ``libcudf.round.round``.

    Returns
    -------
    Column
        The column returned by ``libcudf.round.round``.
    """
    return libcudf.round.round(self, decimal_places=decimals)

def applymap(self, udf, out_dtype=None):
"""Apply an element-wise function to transform the values in the Column.
Expand Down
111 changes: 111 additions & 0 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,117 @@ def __arrow_array__(self, type=None):
"consider using .to_arrow()"
)

def round(self, decimals=0):
    """
    Round a DataFrame to a variable number of decimal places.

    Parameters
    ----------
    decimals : int, dict, Series
        Number of decimal places to round each column to. If an int is
        given, round each column to the same number of places.
        Otherwise dict and Series round to variable numbers of places.
        Column names should be in the keys if `decimals` is a
        dict-like, or in the index if `decimals` is a Series. Any
        columns not included in `decimals` will be left as is. Elements
        of `decimals` which are not columns of the input will be
        ignored. Non-numeric columns are always copied unchanged.

    Returns
    -------
    DataFrame
        A DataFrame with the affected columns rounded to the specified
        number of decimal places.

    Raises
    ------
    ValueError
        If `decimals` is a Series whose index is not unique.
    TypeError
        If `decimals` is not an integer, a dict or a Series.

    Examples
    --------
    >>> df = cudf.DataFrame(
    ...     [(.21, .32), (.01, .67), (.66, .03), (.21, .18)],
    ...     columns=['dogs', 'cats']
    ... )
    >>> df
       dogs  cats
    0  0.21  0.32
    1  0.01  0.67
    2  0.66  0.03
    3  0.21  0.18

    By providing an integer each column is rounded to the same number
    of decimal places

    >>> df.round(1)
       dogs  cats
    0   0.2   0.3
    1   0.0   0.7
    2   0.7   0.0
    3   0.2   0.2

    With a dict, the number of places for specific columns can be
    specified with the column names as key and the number of decimal
    places as value

    >>> df.round({'dogs': 1, 'cats': 0})
       dogs  cats
    0   0.2   0.0
    1   0.0   1.0
    2   0.7   0.0
    3   0.2   0.0

    Using a Series, the number of places for specific columns can be
    specified with the column names as index and the number of
    decimal places as value

    >>> decimals = cudf.Series([0, 1], index=['cats', 'dogs'])
    >>> df.round(decimals)
       dogs  cats
    0   0.2   0.0
    1   0.0   1.0
    2   0.7   0.0
    3   0.2   0.0
    """
    is_numeric = pd.api.types.is_numeric_dtype

    # A cudf Series is handled through its pandas equivalent so that the
    # per-column lookups below work on host data.
    if isinstance(decimals, cudf.Series):
        decimals = decimals.to_pandas()

    if isinstance(decimals, (dict, pd.Series)):
        if (
            isinstance(decimals, pd.Series)
            and not decimals.index.is_unique
        ):
            raise ValueError("Index of decimals must be unique")

        # Membership on a dict tests its keys and on a pandas Series tests
        # its index labels, so no explicit ``.keys()`` call is needed.
        cols = {
            name: col.round(decimals[name])
            if name in decimals and is_numeric(col.dtype)
            else col.copy(deep=True)
            for name, col in self._data.items()
        }
    elif isinstance(decimals, int) or pd.api.types.is_integer(decimals):
        # ``is_integer`` additionally admits NumPy integer scalars such as
        # ``np.int64(2)``; normalise to a plain Python int before use.
        places = int(decimals)
        cols = {
            name: col.round(places)
            if is_numeric(col.dtype)
            else col.copy(deep=True)
            for name, col in self._data.items()
        }
    else:
        raise TypeError(
            "decimals must be an integer, a dict-like or a Series"
        )

    # Rebuild a frame of the same class, keeping the original column
    # labelling metadata and index.
    return self.__class__._from_table(
        Frame(
            data=cudf.core.column_accessor.ColumnAccessor(
                cols,
                multiindex=self._data.multiindex,
                level_names=self._data.level_names,
            )
        ),
        index=self._index,
    )

@annotate("SAMPLE", color="orange", domain="cudf_python")
def sample(
self,
Expand Down
27 changes: 25 additions & 2 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2020, NVIDIA CORPORATION.
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
import pickle
import warnings
from collections import abc as abc
Expand Down Expand Up @@ -3506,8 +3506,31 @@ def mode(self, dropna=True):
return Series(val_counts.index.sort_values(), name=self.name)

def round(self, decimals=0):
"""Round a Series to a configurable number of decimal places.
"""
Round each value in a Series to the given number of decimals.
Parameters
----------
decimals : int, default 0
Number of decimal places to round to. If decimals is negative,
it specifies the number of positions to the left of the decimal
point.
Returns
-------
Series
Rounded values of the Series.
Examples
--------
>>> s = cudf.Series([0.1, 1.4, 2.9])
>>> s.round()
0 0.0
1 1.0
2 3.0
dtype: float64
"""

return Series(
self._column.round(decimals=decimals),
name=self.name,
Expand Down
111 changes: 43 additions & 68 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3243,86 +3243,61 @@ def test_ndim():


@pytest.mark.parametrize(
    "decimals",
    [
        -3,
        0,
        5,
        pd.Series([1, 4, 3, -6], index=["w", "x", "y", "z"]),
        gd.Series([-4, -2, 12], index=["x", "y", "z"]),
        {"w": -1, "x": 15, "y": 2},
    ],
)
def test_dataframe_round(decimals):
    """Check DataFrame.round against pandas for int, Series and dict
    ``decimals``, both without and with extra missing values."""
    pdf = pd.DataFrame(
        {
            "w": np.arange(0.5, 10.5, 1),
            "x": np.random.normal(-100, 100, 10),
            "y": np.array(
                [
                    14.123,
                    2.343,
                    np.nan,
                    0.0,
                    -8.302,
                    np.nan,
                    94.313,
                    -112.236,
                    -8.029,
                    np.nan,
                ]
            ),
            "z": np.repeat([-0.6459412758761901], 10),
        }
    )
    gdf = gd.DataFrame.from_pandas(pdf)

    # pandas cannot consume a cudf Series, so hand it the host equivalent.
    if isinstance(decimals, gd.Series):
        pdecimals = decimals.to_pandas()
    else:
        pdecimals = decimals

    result = gdf.round(decimals)
    expected = pdf.round(pdecimals)
    assert_eq(result, expected)

    # with nulls, maintaining existing null mask
    for c in pdf.columns:
        arr = pdf[c].to_numpy().astype("float64")  # for pandas nulls
        arr.ravel()[np.random.choice(10, 5, replace=False)] = np.nan
        pdf[c] = gdf[c] = arr

    result = gdf.round(decimals)
    expected = pdf.round(pdecimals)

    assert_eq(result, expected)
    # The previous check compared mask bytes with rounded values and dropped
    # the outcome; assert instead that rounding keeps every null in place.
    for c in gdf.columns:
        assert np.array_equal(
            gdf[c].isnull().to_array(), result[c].isnull().to_array()
        )


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit a51caa5

Please sign in to comment.