Align extract_partitioning_index logic with upstream shuffling #60

Merged · 10 commits · May 15, 2024
nemo_curator/_compat.py (1 change: 1 addition & 0 deletions)

@@ -20,3 +20,4 @@
 # TODO: remove when dask min version gets bumped
 DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0")
 DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0")
+DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")
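The new flag follows the existing version-gate pattern in `_compat.py`. As a minimal sketch of how it resolves (assuming `parseVersion` is `packaging.version.parse`, as the surrounding flags imply):

```python
# Sketch only: how the DASK_SHUFFLE_CAST_DTYPE gate evaluates.
# Assumes `parseVersion` is packaging.version.parse, matching the
# helpers already used in nemo_curator/_compat.py.
import dask
from packaging.version import parse as parseVersion

_dask_version = parseVersion(dask.__version__)

# dask uses CalVer release strings, which packaging orders
# chronologically, so this is True exactly for releases newer than
# 2023.12.0 -- the versions whose `partitioning_index` accepts a
# `cast_dtype` keyword.
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")

print(f"dask {dask.__version__}: cast_dtype supported = {DASK_SHUFFLE_CAST_DTYPE}")
```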
nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py (16 changes: 16 additions & 0 deletions)

@@ -16,12 +16,14 @@
 from operator import getitem
 
 import numpy as np
+import pandas as pd
 from dask.base import tokenize
 from dask.dataframe.core import new_dd_object
 from dask.dataframe.shuffle import partitioning_index
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import M
 
+from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE
 from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import rearange_by_column_direct


@@ -129,6 +131,19 @@ def extract_partitioning_index(
     # a partition-wise merge between `left_df` and `right_df`.
     # We call this `global_partitioning_index`:
 
+    if DASK_SHUFFLE_CAST_DTYPE:
+        # Need to use the same type-casting logic as `shuffle`
+        dtypes = {}
+        for col, dtype in left_df[merge_on].dtypes.items():
+            if pd.api.types.is_numeric_dtype(dtype):
+                dtypes[col] = np.float64
+        if not dtypes:
+            dtypes = None
+        cast_dtype = {"cast_dtype": dtypes}
+    else:
+        # `cast_dtype` argument doesn't exist yet
+        cast_dtype = {}
+
     num_bucket_files = bk_mapping.file_id.max() + 1
     global_partitioning_index = left_df[merge_on].map_partitions(
         partitioning_index,
@@ -137,6 +152,7 @@
         enforce_metadata=False,
         transform_divisions=False,
         align_dataframes=False,
+        **cast_dtype,
     )
 
     if total_bucket_partitions < num_bucket_files:
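The float64 cast exists because newer dask releases cast numeric partitioning columns to float64 before hashing during a shuffle, so `extract_partitioning_index` must hash the same way for the partition-wise merge to stay aligned. A small sketch (not part of the PR) of the dtype sensitivity this guards against, using the same `partitioning_index` imported in the diff above:

```python
# Sketch: identical key values can land in different partitions when
# their dtype differs, because the hash runs over the raw value bytes.
# This is the mismatch the `cast_dtype` logic above avoids.
import pandas as pd
from dask.dataframe.shuffle import partitioning_index

df_int = pd.DataFrame({"key": [1, 2, 3]})      # int64 keys
df_flt = df_int.astype({"key": "float64"})     # same values, float64

idx_int = partitioning_index(df_int, 8)  # partition ids for int64 keys
idx_flt = partitioning_index(df_flt, 8)  # partition ids for float64 keys

# The two assignments generally disagree, so casting both sides to
# float64 (as the shuffle does) is what keeps the merge aligned.
print(idx_int, idx_flt, (idx_int == idx_flt).all())
```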