-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Align extract_partitioning_index
logic with upstream shuffling
#60
Changes from 9 commits
68af3f9
644739b
1f28a35
2f5678b
33064e8
5da12ce
25c1eb2
647406f
b48a37c
9ef4aa3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,14 +16,17 @@ | |
from itertools import combinations | ||
from typing import Iterable | ||
|
||
import dask.dataframe as dd | ||
import numpy as np | ||
import pytest | ||
import yaml | ||
from dask import config | ||
from dask.dataframe.utils import assert_eq | ||
from distributed import Client | ||
|
||
from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash | ||
from nemo_curator.datasets import DocumentDataset | ||
from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index | ||
from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from | ||
|
||
cudf = gpu_only_import("cudf") | ||
|
@@ -367,3 +370,74 @@ def test_from_yaml(self, tmpdir): | |
config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml") | ||
for param in yaml_params: | ||
assert getattr(config, param) == yaml_params[param] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"backend", | ||
[ | ||
"pandas", | ||
pytest.param( | ||
"cudf", | ||
marks=pytest.mark.gpu, | ||
), | ||
], | ||
) | ||
def test_extract_partitioning_index(backend): | ||
|
||
def add_partiton_info(df, partition_info=None): | ||
if partition_info is None: | ||
df["file_id"] = -1 | ||
else: | ||
df["file_id"] = partition_info["number"] | ||
return df | ||
|
||
with config.set({"dataframe.backend": backend}): | ||
|
||
# Create a random `unshuffled` DataFrame with a | ||
# "part_id" column to be used as the shuffle index | ||
npartitions_left = 7 | ||
unshuffled = dd.from_dict( | ||
{"part_id": np.random.randint(25, size=1000, dtype="int32")}, | ||
npartitions=npartitions_left, | ||
) | ||
|
||
# Create a `bk_mapping` DataFrame that defines | ||
# the "correct" mapping beween "part_id" and | ||
# the destination partition ("file_id") | ||
npartitions_right = 5 | ||
bk_mapping = ( | ||
dd.from_dict( | ||
{"part_id": np.arange(25, dtype="int32")}, | ||
npartitions=npartitions_right, | ||
) | ||
.shuffle("part_id") | ||
.map_partitions(add_partiton_info) | ||
.compute() | ||
) | ||
|
||
# Use `extract_partitioning_index` to calculate | ||
# the partitioning index and assign it as a new | ||
# "_partitions" column | ||
result, _ = extract_partitioning_index( | ||
unshuffled, | ||
"part_id", | ||
bk_mapping, | ||
npartitions_right, | ||
npartitions_right, | ||
Comment on lines
+425
to
+426
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm having trouble wrapping my head around the "correct" way to test the batched case. However, my sense is that this test already covers the critical requirement that we are extracting a partitioning index that is consistent with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. I don't think this specific bug impacts the batched case any differently than the reproducer here so I think it should be good to go. |
||
) | ||
|
||
# Rename the "_partitions" column, shuffle by "part_id", | ||
# and then assign a "file_id" column to reflect the final | ||
# partition of each row | ||
check = ( | ||
result.rename(columns={"_partitions": "expected_file_id"}) | ||
.shuffle( | ||
"part_id", | ||
npartitions=npartitions_right, | ||
) | ||
.map_partitions(add_partiton_info) | ||
.compute() | ||
) | ||
|
||
# Check that the real and expected partitions match | ||
assert (check["file_id"] == check["expected_file_id"]).all() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that I moved this import into
merge_left_to_shuffled_right
, becauseshuffle_utils
currently requirescudf
/dask_cuda
, while other logic in this module does not.