Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align extract_partitioning_index logic with upstream shuffling #60

Merged
merged 10 commits into from
May 15, 2024
1 change: 1 addition & 0 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
# TODO: remove when dask min version gets bumped
DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0")
DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0")
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")
25 changes: 23 additions & 2 deletions nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from operator import getitem

import numpy as np
import pandas as pd
from dask.base import tokenize
from dask.dataframe.core import new_dd_object
from dask.dataframe.shuffle import partitioning_index
from dask.highlevelgraph import HighLevelGraph
from dask.utils import M

from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import rearange_by_column_direct
from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE


def _split_part(part, nsplits):
Expand Down Expand Up @@ -129,6 +130,21 @@ def extract_partitioning_index(
# a partition-wise merge between `left_df` and `right_df`.
# We call this `global_partitioning_index`:

if DASK_SHUFFLE_CAST_DTYPE:
# Need to use the same type-casting logic as `shuffle`
dtypes = {}
if not isinstance(merge_on, list):
merge_on = [merge_on]
for col, dtype in left_df[merge_on].dtypes.items():
if pd.api.types.is_numeric_dtype(dtype):
dtypes[col] = np.float64
if not dtypes:
dtypes = None
cast_dtype = {"cast_dtype": dtypes}
else:
# `cast_dtype` argument doesn't exist yet
cast_dtype = {}

num_bucket_files = bk_mapping.file_id.max() + 1
global_partitioning_index = left_df[merge_on].map_partitions(
partitioning_index,
Expand All @@ -137,6 +153,7 @@ def extract_partitioning_index(
enforce_metadata=False,
transform_divisions=False,
align_dataframes=False,
**cast_dtype,
)

if total_bucket_partitions < num_bucket_files:
Expand All @@ -157,7 +174,7 @@ def extract_partitioning_index(
# want to send the rows of `left_df` to the partition
# indices encoded in `global_partitioning_index`. Instead, we
# need to take a modulus with `parts_per_bucket_batch` to
# define a `"_partitoins"` column.
# define a `"_partitions"` column.
left_df["_partitions"] = global_partitioning_index % parts_per_bucket_batch

return left_df, global_partitioning_index
Expand Down Expand Up @@ -195,6 +212,10 @@ def merge_left_to_shuffled_right(
subset_bucket_df,
merge_on,
):
from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import (
rearange_by_column_direct,
)
Comment on lines +215 to +217
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that I moved this import into merge_left_to_shuffled_right, because shuffle_utils currently requires cudf/dask_cuda, while other logic in this module does not.


# We are merging an unshuffled batch of "left" partitions
# with a shuffled batch of "right" partitions. To minimize
# data movement, we can manually rearrange the "left" batch
Expand Down
74 changes: 74 additions & 0 deletions tests/test_fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
from itertools import combinations
from typing import Iterable

import dask.dataframe as dd
import numpy as np
import pytest
import yaml
from dask import config
from dask.dataframe.utils import assert_eq
from distributed import Client

from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index
from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from

cudf = gpu_only_import("cudf")
Expand Down Expand Up @@ -367,3 +370,74 @@ def test_from_yaml(self, tmpdir):
config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml")
for param in yaml_params:
assert getattr(config, param) == yaml_params[param]


@pytest.mark.parametrize(
    "backend",
    [
        "pandas",
        pytest.param(
            "cudf",
            marks=pytest.mark.gpu,
        ),
    ],
)
def test_extract_partitioning_index(backend):
    """Verify that ``extract_partitioning_index`` produces partition indices
    consistent with an actual ``DataFrame.shuffle`` on the same column.

    The test builds an unshuffled frame, computes the expected destination
    partition for every row via ``extract_partitioning_index``, then performs
    a real shuffle and checks that each row landed where predicted.
    """

    def add_partition_info(df, partition_info=None):
        # Tag each row with the id of the partition it currently lives in.
        # `partition_info` is None during Dask's meta/emulation pass, so use
        # a -1 sentinel there.
        if partition_info is None:
            df["file_id"] = -1
        else:
            df["file_id"] = partition_info["number"]
        return df

    with config.set({"dataframe.backend": backend}):

        # Create a random `unshuffled` DataFrame with a
        # "part_id" column to be used as the shuffle index
        npartitions_left = 7
        unshuffled = dd.from_dict(
            {"part_id": np.random.randint(25, size=1000, dtype="int32")},
            npartitions=npartitions_left,
        )

        # Create a `bk_mapping` DataFrame that defines
        # the "correct" mapping between "part_id" and
        # the destination partition ("file_id")
        npartitions_right = 5
        bk_mapping = (
            dd.from_dict(
                {"part_id": np.arange(25, dtype="int32")},
                npartitions=npartitions_right,
            )
            .shuffle("part_id")
            .map_partitions(add_partition_info)
            .compute()
        )

        # Use `extract_partitioning_index` to calculate
        # the partitioning index and assign it as a new
        # "_partitions" column
        result, _ = extract_partitioning_index(
            unshuffled,
            "part_id",
            bk_mapping,
            npartitions_right,
            npartitions_right,
        )

        # Rename the "_partitions" column, shuffle by "part_id",
        # and then assign a "file_id" column to reflect the final
        # partition of each row
        check = (
            result.rename(columns={"_partitions": "expected_file_id"})
            .shuffle(
                "part_id",
                npartitions=npartitions_right,
            )
            .map_partitions(add_partition_info)
            .compute()
        )

        # Check that the real and expected partitions match
        assert (check["file_id"] == check["expected_file_id"]).all()