Skip to content

Commit

Permalink
Revert "Remove tests that rely on MultiIndex merging."
Browse files Browse the repository at this point in the history
This reverts commit 0ef8213.
  • Loading branch information
vyasr committed Apr 19, 2022
1 parent b7e09a3 commit 2c83f3e
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions python/cudf/cudf/tests/test_joining.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,119 @@ def test_categorical_typecast_outer_one_cat(dtype):
assert result["key"].dtype == left["key"].dtype.categories.dtype


@pytest.mark.parametrize(
    ("lhs", "rhs"),
    [
        (["a", "b"], ["a"]),
        (["a"], ["a", "b"]),
        (["a", "b"], ["b"]),
        (["b"], ["a", "b"]),
        (["a"], ["a"]),
    ],
)
@pytest.mark.parametrize("how", ["left", "right", "outer", "inner"])
@pytest.mark.parametrize("level", ["a", "b", 0, 1])
def test_index_join(lhs, rhs, how, level):
    """Compare cudf ``Index.join(level=...)`` against pandas.

    ``lhs`` and ``rhs`` are the column lists turned into the left and right
    indexes, ``how`` is the join type and ``level`` is the level (name or
    position) handed to ``join``.
    """
    pandas_left = pd.DataFrame({"a": [2, 3, 1, 4], "b": [3, 7, 8, 1]})
    pandas_right = pd.DataFrame({"a": [1, 5, 4, 0], "b": [3, 9, 8, 4]})
    cudf_left = cudf.from_pandas(pandas_left)
    cudf_right = cudf.from_pandas(pandas_right)

    def joined_frame(left_frame, right_frame):
        # The same pipeline is applied to the pandas and the cudf inputs:
        # build both indexes, join them, and flatten the result to a frame.
        left_index = left_frame.set_index(lhs).index
        right_index = right_frame.set_index(rhs).index
        result = left_index.join(right_index, level=level, how=how)
        return result.to_frame(index=False)

    expected = joined_frame(pandas_left, pandas_right)
    got = joined_frame(cudf_left, cudf_right)

    assert_join_results_equal(expected, got, how=how)


def test_index_join_corner_cases():
    """Check three ``Index.join`` corner cases against pandas.

    The scenarios run in sequence and share the input frames:

    1. MultiIndex joined with an Index whose name ("c") matches neither the
       requested ``level`` nor any level of the left index.
    2. Two single-level indexes joined with ``sort=True`` and no ``level``.
    3. Categorical key columns on the cudf side only; the cudf result is
       cast to the pandas dtype before comparing.
    """
    l_pdf = pd.DataFrame({"a": [2, 3, 1, 4], "b": [3, 7, 8, 1]})
    r_pdf = pd.DataFrame(
        {"a": [1, 5, 4, 0], "b": [3, 9, 8, 4], "c": [2, 3, 6, 0]}
    )
    l_df = cudf.from_pandas(l_pdf)
    r_df = cudf.from_pandas(r_pdf)

    # Scenario 1: join when the right index name doesn't match the level.
    lhs = ["a", "b"]
    # level and rhs don't match
    rhs = ["c"]
    level = "b"
    how = "outer"
    p_lhs = l_pdf.set_index(lhs).index
    p_rhs = r_pdf.set_index(rhs).index
    g_lhs = l_df.set_index(lhs).index
    g_rhs = r_df.set_index(rhs).index
    expected = p_lhs.join(p_rhs, level=level, how=how).to_frame(index=False)
    got = g_lhs.join(g_rhs, level=level, how=how).to_frame(index=False)

    assert_join_results_equal(expected, got, how=how)

    # Scenario 2: sort is supported only in case of two non-MultiIndex join,
    # so both sides are single-level indexes and no ``level`` is passed.
    lhs = ["a"]
    rhs = ["a"]
    how = "left"
    p_lhs = l_pdf.set_index(lhs).index
    p_rhs = r_pdf.set_index(rhs).index
    g_lhs = l_df.set_index(lhs).index
    g_rhs = r_df.set_index(rhs).index
    expected = p_lhs.join(p_rhs, how=how, sort=True)
    got = g_lhs.join(g_rhs, how=how, sort=True)

    assert_join_results_equal(expected, got, how=how)

    # Scenario 3: Pandas Index.join on categorical column returns generic
    # column but cudf will be returning a categorical column itself.
    # Only the cudf frames are converted to category below; the pandas
    # frames keep their original dtype, and the cudf result is cast to
    # that dtype before the comparison.
    lhs = ["a", "b"]
    rhs = ["a"]
    level = "a"
    how = "inner"
    l_df["a"] = l_df["a"].astype("category")
    r_df["a"] = r_df["a"].astype("category")
    p_lhs = l_pdf.set_index(lhs).index
    p_rhs = r_pdf.set_index(rhs).index
    g_lhs = l_df.set_index(lhs).index
    g_rhs = r_df.set_index(rhs).index
    expected = p_lhs.join(p_rhs, level=level, how=how).to_frame(index=False)
    got = g_lhs.join(g_rhs, level=level, how=how).to_frame(index=False)

    got["a"] = got["a"].astype(expected["a"].dtype)

    assert_join_results_equal(expected, got, how=how)


def test_index_join_exception_cases():
    """Check that unsupported cudf ``Index.join`` inputs raise."""
    left_frame = cudf.DataFrame({"a": [2, 3, 1, 4], "b": [3, 7, 8, 1]})
    right_frame = cudf.DataFrame(
        {"a": [1, 5, 4, 0], "b": [3, 9, 8, 4], "c": [2, 3, 6, 0]}
    )
    left_keys = ["a", "b"]
    join_type = "outer"

    # Joining a MultiIndex with another MultiIndex is expected to raise
    # a TypeError.
    left_index = left_frame.set_index(left_keys).index
    right_index = right_frame.set_index(["a", "c"]).index
    with pytest.raises(TypeError):
        left_index.join(right_index, level="a", how=join_type)

    # An improper ``level`` (a list instead of an int or scalar) is
    # expected to raise a ValueError.
    left_index = left_frame.set_index(left_keys).index
    right_index = right_frame.set_index(["a"]).index
    with pytest.raises(ValueError):
        left_index.join(right_index, level=["a"], how=join_type)


def test_typecast_on_join_indexes():
join_data_l = cudf.Series([1, 2, 3, 4, 5], dtype="int8")
join_data_r = cudf.Series([1, 2, 3, 4, 6], dtype="int32")
Expand Down Expand Up @@ -2055,3 +2168,16 @@ def test_join_redundant_params():
lhs.merge(rhs, right_on="a", left_index=True, right_index=True)
with pytest.raises(ValueError):
lhs.merge(rhs, left_on="c", right_on="b")


def test_join_multiindex_index():
    """Inner-join a MultiIndex with an Index that shares the name "a".

    The expected value comes from running the same join on the pandas
    conversions of both indexes.
    """
    join_type = "inner"

    multi_frame = cudf.DataFrame({"a": [2, 3, 1], "b": [3, 4, 2]})
    lhs = multi_frame.set_index(["a", "b"]).index

    single_frame = cudf.DataFrame({"a": [1, 4, 3]})
    rhs = single_frame.set_index("a").index

    pandas_lhs = lhs.to_pandas()
    pandas_rhs = rhs.to_pandas()
    expect = pandas_lhs.join(pandas_rhs, how=join_type)
    got = lhs.join(rhs, how=join_type)

    assert_join_results_equal(expect, got, how=join_type)

0 comments on commit 2c83f3e

Please sign in to comment.