Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Column equality testing fixes #10011

102 changes: 80 additions & 22 deletions python/cudf/cudf/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,31 @@
import pandas as pd

import cudf
from cudf.api.types import is_categorical_dtype, is_numeric_dtype
from cudf._lib.unary import is_nan
from cudf.api.types import (
is_categorical_dtype,
is_decimal_dtype,
is_interval_dtype,
is_list_dtype,
is_numeric_dtype,
is_string_dtype,
is_struct_dtype,
)
from cudf.core._compat import PANDAS_GE_110


def dtype_can_compare_equal_to_other(dtype):
# return True if values of this dtype can compare
# as equal to equal values of a different dtype
return not (
is_string_dtype(dtype)
or is_list_dtype(dtype)
or is_struct_dtype(dtype)
or is_decimal_dtype(dtype)
or is_interval_dtype(dtype)
)


def _check_isinstance(left, right, obj):
if not isinstance(left, obj):
raise AssertionError(
Expand Down Expand Up @@ -146,6 +167,9 @@ def assert_column_equal(
msg1 = f"{left.dtype}"
msg2 = f"{right.dtype}"
raise_assert_detail(obj, "Dtypes are different", msg1, msg2)
else:
if left.null_count == len(left) and right.null_count == len(right):
return True

if check_datetimelike_compat:
if np.issubdtype(left.dtype, np.datetime64):
Expand Down Expand Up @@ -201,39 +225,74 @@ def assert_column_equal(
):
left = left.astype(left.categories.dtype)
right = right.astype(right.categories.dtype)

columns_equal = False
try:
columns_equal = (
(
cp.all(left.isnull().values == right.isnull().values)
and cp.allclose(
if left.size == right.size == 0:
columns_equal = True
elif not (
(
not dtype_can_compare_equal_to_other(left.dtype)
and is_numeric_dtype(right)
)
or (
is_numeric_dtype(left)
and not dtype_can_compare_equal_to_other(right)
)
):
try:
# nulls must be in the same places for all dtypes
columns_equal = cp.all(
left.isnull().values == right.isnull().values
)

if is_numeric_dtype(left) and columns_equal and not check_exact:
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
# non-null values must be the same
columns_equal = cp.allclose(
left[left.isnull().unary_operator("not")].values,
right[right.isnull().unary_operator("not")].values,
)
)
if not check_exact and is_numeric_dtype(left)
else left.equals(right)
)
except TypeError as e:
if str(e) != "Categoricals can only compare with the same type":
raise e
if is_categorical_dtype(left) and is_categorical_dtype(right):
left = left.astype(left.categories.dtype)
right = right.astype(right.categories.dtype)
if columns_equal and (
left.dtype.kind == right.dtype.kind == "f"
):
columns_equal = cp.all(
is_nan(left).values == is_nan(right).values
)
else:
columns_equal = left.equals(right)
except TypeError as e:
if str(e) != "Categoricals can only compare with the same type":
raise e
else:
columns_equal = False
if is_categorical_dtype(left) and is_categorical_dtype(right):
left = left.astype(left.categories.dtype)
right = right.astype(right.categories.dtype)
if not columns_equal:
msg1 = f"{left.values_host}"
msg2 = f"{right.values_host}"
ldata = str([val for val in left.to_pandas(nullable=True)])
rdata = str([val for val in right.to_pandas(nullable=True)])
msg1 = f"{ldata}"
msg2 = f"{rdata}"
vyasr marked this conversation as resolved.
Show resolved Hide resolved
try:
diff = left.apply_boolean_mask(left != right).size
diff = 0
for i in range(left.size):
if not null_safe_scalar_equals(left[i], right[i]):
diff += 1
diff = diff * 100.0 / left.size
except BaseException:
diff = 100.0
raise_assert_detail(
obj, f"values are different ({np.round(diff, 5)} %)", msg1, msg2,
obj,
f"values are different ({np.round(diff, 5)} %)",
{ldata},
{rdata},
)


def null_safe_scalar_equals(left, right):
if left in {cudf.NA, np.nan} or right in {cudf.NA, np.nan}:
return left is right
return left == right
vyasr marked this conversation as resolved.
Show resolved Hide resolved


def assert_index_equal(
left,
right,
Expand Down Expand Up @@ -358,7 +417,6 @@ def assert_index_equal(
obj=mul_obj,
)
return

assert_column_equal(
left._columns[0],
right._columns[0],
Expand Down
32 changes: 32 additions & 0 deletions python/cudf/cudf/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import pytest

import cudf
from cudf.core.column.column import as_column, full
from cudf.testing import (
assert_frame_equal,
assert_index_equal,
assert_series_equal,
)
from cudf.testing._utils import NUMERIC_TYPES, OTHER_TYPES, assert_eq
from cudf.testing.testing import assert_column_equal


@pytest.mark.parametrize("rdata", [[1, 2, 5], [1, 2, 6], [1, 2, 5, 6]])
Expand Down Expand Up @@ -119,6 +121,36 @@ def test_basic_assert_series_equal(
)


@pytest.mark.parametrize(
"other",
[
as_column(["1", "2", "3"]),
as_column([[1], [2], [3]]),
as_column([{"a": 1}, {"a": 2}, {"a": 3}]),
],
)
def test_assert_column_equal_dtype_edge_cases(other):
# string series should be 100% different
# even when the elements are the same
base = as_column([1, 2, 3])

# for these dtypes, the diff should always be 100% regardless of the values
with pytest.raises(
AssertionError, match=r".*values are different \(100.0 %\).*"
):
assert_column_equal(base, other, check_dtype=False)

# the exceptions are the empty and all null cases
assert_column_equal(base[:0], other[:0], check_dtype=False)
assert_column_equal(other[:0], base[:0], check_dtype=False)

base = full(len(base), fill_value=cudf.NA, dtype=base.dtype)
other = full(len(other), fill_value=cudf.NA, dtype=other.dtype)

assert_column_equal(base, other, check_dtype=False)
assert_column_equal(other, base, check_dtype=False)


@pytest.mark.parametrize(
"rdtype", [["int8", "int16", "int64"], ["int64", "int16", "int8"]]
)
Expand Down