Skip to content

Commit

Permalink
Add cudf.DataFrame.applymap (#10542)
Browse files Browse the repository at this point in the history
Naive implementation of `DataFrame.applymap` that just calls `apply` in a loop over columns.

This could theoretically be made much faster within our framework. The current approach requires at worst `N` compilations and `M` kernel launches, where `N` is the number of distinct dtypes in the data and `M` is the total number of columns. As an improvement, we could instead launch a single kernel that populates the entire output. That would still suffer from the compilation bottleneck, though: the function must be compiled before its output dtype can be determined, and this has to be done once for each distinct dtype in the data.

Part of #10169

Authors:
  - https://github.com/brandon-b-miller
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10542
  • Loading branch information
brandon-b-miller authored Apr 13, 2022
1 parent c72868e commit ce56bc3
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Function application, GroupBy & window
:toctree: api/

DataFrame.apply
DataFrame.applymap
DataFrame.apply_chunks
DataFrame.apply_rows
DataFrame.pipe
Expand Down
64 changes: 64 additions & 0 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections.abc import Iterable, Mapping, Sequence
from typing import (
Any,
Callable,
Dict,
List,
MutableMapping,
Expand All @@ -25,6 +26,7 @@
)

import cupy
import numba
import numpy as np
import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -3708,6 +3710,68 @@ def apply(

return self._apply(func, _get_row_kernel, *args, **kwargs)

def applymap(
    self,
    func: Callable[[Any], Any],
    na_action: Union[str, None] = None,
    **kwargs,
) -> DataFrame:
    """
    Apply a function to a DataFrame elementwise.

    This method applies a function that accepts and returns a scalar
    to every element of a DataFrame.

    Parameters
    ----------
    func : callable
        Python function, returns a single value from a single value.
    na_action : {None, 'ignore'}, default None
        If 'ignore', propagate NaN values, without passing them to func.
    **kwargs
        Not supported yet; supplying any keyword argument raises
        ``NotImplementedError``.

    Returns
    -------
    DataFrame
        Transformed DataFrame.

    Raises
    ------
    NotImplementedError
        If any ``**kwargs`` are supplied.
    ValueError
        If ``na_action`` is neither ``'ignore'`` nor ``None``.
    """
    if kwargs:
        raise NotImplementedError(
            "DataFrame.applymap does not yet support **kwargs."
        )

    # Validate without hashing the argument: a set-membership test would
    # turn an unhashable value (e.g. a list) into
    # "TypeError: unhashable type" instead of the ValueError below.
    na_action_ok = na_action is None or (
        isinstance(na_action, str) and na_action == "ignore"
    )
    if not na_action_ok:
        raise ValueError(
            f"na_action must be 'ignore' or None. Got {repr(na_action)}"
        )

    if na_action == "ignore":
        devfunc = numba.cuda.jit(device=True)(func)

        # promote to a null-ignoring function
        # this code is never run in python, it only
        # exists to provide numba with the correct
        # bytecode to generate the equivalent PTX
        # as a null-ignoring version of the function
        def _func(x):  # pragma: no cover
            if x is cudf.NA:
                return cudf.NA
            else:
                return devfunc(x)

    else:
        _func = func

    # TODO: naive implementation
    # this could be written as a single kernel
    # Each column is wrapped in a temporary, unnamed Series so the
    # existing Series.apply machinery can be reused column by column.
    result = {
        name: Series._from_data({None: col}).apply(_func)
        for name, col in self._data.items()
    }

    return DataFrame._from_data(result, index=self.index)

@_cudf_nvtx_annotate
@applyutils.doc_apply()
def apply_rows(
Expand Down
44 changes: 43 additions & 1 deletion python/cudf/cudf/tests/test_applymap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pytest

from cudf import Series
from cudf import NA, DataFrame, Series
from cudf.testing import _utils as utils


Expand Down Expand Up @@ -58,3 +58,45 @@ def test_applymap_change_out_dtype():
expect = np.array(data, dtype=float)
got = out.to_numpy()
np.testing.assert_array_equal(expect, got)


@pytest.mark.parametrize(
    "data",
    [
        # int / int
        {"a": [1, 2, 3], "b": [4, 5, 6]},
        # int / float
        {"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]},
        # int / bool
        {"a": [1, 2, 3], "b": [True, False, True]},
        # columns containing NA entries
        {"a": [1, NA, 2], "b": [NA, 4, NA]},
    ],
)
@pytest.mark.parametrize(
    "func",
    [
        lambda x: x + 1,
        lambda x: x - 0.5,
        lambda x: 2 if x is NA else 2 + (x + 1) / 4.1,
        lambda x: 42,
    ],
)
@pytest.mark.parametrize("na_action", [None, "ignore"])
def test_applymap_dataframe(data, func, na_action):
    """Compare cudf ``DataFrame.applymap`` against pandas.

    Every combination of input frame, scalar function and ``na_action``
    is run through both libraries; the results must agree, with dtype
    differences ignored.
    """
    gpu_frame = DataFrame(data)

    # The pandas reference is derived from the cudf frame itself so that
    # both sides start from the same values.
    expected = gpu_frame.to_pandas(nullable=True).applymap(
        func, na_action=na_action
    )
    actual = gpu_frame.applymap(func, na_action=na_action)

    utils.assert_eq(expected, actual, check_dtype=False)


def test_applymap_raise_cases():
    """Check the argument validation performed by ``DataFrame.applymap``."""
    frame = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    def add_offset(x, some_kwarg=0):
        return x + some_kwarg

    # Extra keyword arguments for the UDF are not supported yet.
    with pytest.raises(NotImplementedError):
        frame.applymap(add_offset, some_kwarg=1)

    # Only None and "ignore" are accepted for na_action.
    with pytest.raises(ValueError):
        frame.applymap(add_offset, na_action="some_invalid_option")
29 changes: 29 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_applymap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

import pytest
from pandas import NA

from dask import dataframe as dd

from dask_cudf.tests.utils import _make_random_frame


@pytest.mark.parametrize(
    "func",
    [
        lambda x: x + 1,
        lambda x: x - 0.5,
        lambda x: 2 if x is NA else 2 + (x + 1) / 4.1,
        lambda x: 42,
    ],
)
@pytest.mark.parametrize("has_na", [True, False])
def test_applymap_basic(func, has_na):
    """Compare dask_cudf ``applymap`` with dask-on-pandas ``applymap``."""
    # NOTE(review): ``has_na`` is parametrized but never forwarded;
    # ``include_na`` is hard-coded to False below, so both parametrizations
    # run on identical NA-free data. Confirm whether NA coverage was
    # intended before wiring it through.
    n_rows = 2000
    pdf, dgdf = _make_random_frame(n_rows, include_na=False)

    # Partition the pandas reference the same way as the cudf-backed frame.
    dpdf = dd.from_pandas(pdf, npartitions=dgdf.npartitions)

    dd.assert_eq(
        dpdf.applymap(func),
        dgdf.applymap(func),
        check_dtype=False,
    )
13 changes: 4 additions & 9 deletions python/dask_cudf/dask_cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

import operator

import numpy as np
Expand All @@ -8,6 +10,8 @@

import cudf

from dask_cudf.tests.utils import _make_random_frame


def _make_empty_frame(npartitions=2):
df = pd.DataFrame({"x": [], "y": []})
Expand All @@ -16,15 +20,6 @@ def _make_empty_frame(npartitions=2):
return dgf


def _make_random_frame(nelem, npartitions=2):
    """Return a random pandas frame and a partitioned cudf-backed twin.

    The pandas frame has two columns, ``x`` and ``y``, each filled with
    ``nelem`` values drawn from ``np.random.random``. It is converted to a
    cudf DataFrame and split into ``npartitions`` partitions through
    ``dd.from_pandas``.

    Returns
    -------
    tuple
        ``(pandas_frame, partitioned_frame)``.
    """
    # "x" is drawn before "y", matching the original evaluation order.
    columns = {
        "x": np.random.random(size=nelem),
        "y": np.random.random(size=nelem),
    }
    df = pd.DataFrame(columns)
    dgf = dd.from_pandas(
        cudf.DataFrame.from_pandas(df), npartitions=npartitions
    )
    return df, dgf


def _make_random_frame_float(nelem, npartitions=2):
df = pd.DataFrame(
{
Expand Down
21 changes: 21 additions & 0 deletions python/dask_cudf/dask_cudf/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

import numpy as np
import pandas as pd

import dask.dataframe as dd

import cudf


def _make_random_frame(nelem, npartitions=2, include_na=False):
    """Return a random pandas frame and a partitioned cudf-backed twin.

    Parameters
    ----------
    nelem : int
        Number of rows to generate.
    npartitions : int, default 2
        Number of partitions passed to ``dd.from_pandas``.
    include_na : bool, default False
        If True, every other row of column ``x`` (starting with the
        first) is set to a missing value.

    Returns
    -------
    tuple
        ``(pandas_frame, partitioned_frame)``, where the second element
        is built from a cudf copy of the first.
    """
    df = pd.DataFrame(
        {"x": np.random.random(size=nelem), "y": np.random.random(size=nelem)}
    )

    if include_na:
        # Assign through a single .loc call. The chained form
        # ``df["x"][::2] = ...`` writes into an intermediate Series, which
        # pandas warns about and which does not modify ``df`` at all when
        # Copy-on-Write is enabled.
        df.loc[::2, "x"] = pd.NA

    gdf = cudf.DataFrame.from_pandas(df)
    dgf = dd.from_pandas(gdf, npartitions=npartitions)
    return df, dgf

0 comments on commit ce56bc3

Please sign in to comment.