diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py
index af45f290c..37592b188 100644
--- a/nemo_curator/datasets/doc_dataset.py
+++ b/nemo_curator/datasets/doc_dataset.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import dask.dataframe as dd
-import dask_cudf
 
 from nemo_curator.utils.distributed_utils import read_data, write_to_disk
 from nemo_curator.utils.file_utils import get_all_files_paths_under
@@ -182,10 +181,7 @@ def _read_json_or_parquet(
             )
             dfs.append(df)
 
-        if backend == "cudf":
-            raw_data = dask_cudf.concat(dfs, ignore_unknown_divisions=True)
-        else:
-            raw_data = dd.concat(dfs, ignore_unknown_divisions=True)
+        raw_data = dd.concat(dfs, ignore_unknown_divisions=True)
 
     elif isinstance(input_files, str):
         # Single file
diff --git a/nemo_curator/gpu_deduplication/utils.py b/nemo_curator/gpu_deduplication/utils.py
index ed69477be..f6faefe77 100644
--- a/nemo_curator/gpu_deduplication/utils.py
+++ b/nemo_curator/gpu_deduplication/utils.py
@@ -13,84 +13,8 @@
 # limitations under the License.
 
 import argparse
-import logging
-import os
-import socket
-from contextlib import nullcontext
 from time import time
 
-import cudf
-from dask_cuda import LocalCUDACluster
-from distributed import Client, performance_report
-
-
-def create_logger(rank, log_file, name="logger", log_level=logging.INFO):
-    # Create the logger
-    logger = logging.getLogger(name)
-    logger.setLevel(log_level)
-
-    myhost = socket.gethostname()
-
-    extra = {"host": myhost, "rank": rank}
-    formatter = logging.Formatter(
-        "%(asctime)s | %(host)s | Rank %(rank)s | %(message)s"
-    )
-
-    # File handler for output
-    file_handler = logging.FileHandler(log_file, mode="a")
-    file_handler.setFormatter(formatter)
-    logger.addHandler(file_handler)
-    logger = logging.LoggerAdapter(logger, extra)
-
-    return logger
-
-
-# TODO: Remove below to use nemo_curator.distributed_utils.get_client
-def get_client(args) -> Client:
-    if args.scheduler_address:
-        if args.scheduler_file:
-            raise ValueError(
-                "Only one of scheduler_address or scheduler_file can be provided"
-            )
-        else:
-            return Client(address=args.scheduler_address, timeout="30s")
-    elif args.scheduler_file:
-        return Client(scheduler_file=args.scheduler_file, timeout="30s")
-    else:
-        extra_kwargs = (
-            {
-                "enable_tcp_over_ucx": True,
-                "enable_nvlink": True,
-                "enable_infiniband": False,
-                "enable_rdmacm": False,
-            }
-            if args.nvlink_only and args.protocol == "ucx"
-            else {}
-        )
-
-        cluster = LocalCUDACluster(
-            rmm_pool_size=args.rmm_pool_size,
-            protocol=args.protocol,
-            rmm_async=True,
-            **extra_kwargs,
-        )
-        return Client(cluster)
-
-
-def performance_report_if(path=None, report_name="dask-profile.html"):
-    if path is not None:
-        return performance_report(os.path.join(path, report_name))
-    else:
-        return nullcontext()
-
-
-# TODO: Remove below to use nemo_curator.distributed_utils._enable_spilling
-def enable_spilling():
-    """
-    Enables spilling to host memory for cudf
-    """
-    cudf.set_option("spill", True)
-
 
 def get_num_workers(client):
     """
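Both hunks above remove hard GPU imports from code that must also run on CPU-only installs: `doc_dataset.py` drops `dask_cudf` because `dd.concat` dispatches on the partition type of its inputs, so a single call covers both backends, and the legacy helpers deleted from `gpu_deduplication/utils.py` reappear later in this patch under `nemo_curator.log` and `nemo_curator.utils.distributed_utils`. A minimal sketch of the `dd.concat` behavior the first change relies on (pandas-backed here; the same call yields a cudf-backed collection when the inputs are Dask-cuDF):

```python
import dask.dataframe as dd
import pandas as pd

# Three single-partition Dask DataFrames built from pandas.
parts = [dd.from_pandas(pd.DataFrame({"x": [i]}), npartitions=1) for i in range(3)]

# Backend-agnostic concat: no cudf-vs-pandas branch needed at the call site.
combined = dd.concat(parts, ignore_unknown_divisions=True)
print(type(combined._meta))  # pandas DataFrame; a cudf DataFrame for dask_cudf inputs
```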
diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py
index d845441f3..116fc2aff 100644
--- a/nemo_curator/modules/__init__.py
+++ b/nemo_curator/modules/__init__.py
@@ -15,27 +15,33 @@
 from .add_id import AddId
 from .exact_dedup import ExactDuplicates
 from .filter import Filter, Score, ScoreFilter
-from .fuzzy_dedup import LSH, MinHash
+
 from .meta import Sequential
 from .modify import Modify
 from .task import TaskDecontamination
 
-# Pytorch related imports must come after all imports that require cugraph,
-# because of context cleanup issues b/w pytorch and cugraph
-# See this issue: https://github.com/rapidsai/cugraph/issues/2718
-from .distributed_data_classifier import DomainClassifier, QualityClassifier
-
 __all__ = [
-    "DomainClassifier",
     "ExactDuplicates",
     "Filter",
-    "LSH",
-    "MinHash",
     "Modify",
-    "QualityClassifier",
     "Score",
     "ScoreFilter",
     "Sequential",
     "TaskDecontamination",
     "AddId",
 ]
+
+# GPU packages
+try:
+    from .fuzzy_dedup import LSH, MinHash
+
+    __all__ += ["LSH", "MinHash"]
+except ModuleNotFoundError:
+    pass
+
+# Pytorch related imports must come after all imports that require cugraph,
+# because of context cleanup issues b/w pytorch and cugraph
+# See this issue: https://github.com/rapidsai/cugraph/issues/2718
+from .distributed_data_classifier import DomainClassifier, QualityClassifier
+
+__all__ += ["DomainClassifier", "QualityClassifier"]
diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py
index 5d960ac6e..2831f516f 100644
--- a/nemo_curator/modules/exact_dedup.py
+++ b/nemo_curator/modules/exact_dedup.py
@@ -28,7 +28,8 @@
 from nemo_curator._compat import DASK_P2P_ERROR
 from nemo_curator.datasets import DocumentDataset
-from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import performance_report_if
 from nemo_curator.utils.gpu_utils import is_cudf_type
diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py
index 3b0576058..1bb5aa0e3 100644
--- a/nemo_curator/modules/fuzzy_dedup.py
+++ b/nemo_curator/modules/fuzzy_dedup.py
@@ -22,9 +22,6 @@
 from typing import List, Tuple, Union
 
 import cudf
-import cugraph
-import cugraph.dask as dcg
-import cugraph.dask.comms.comms as Comms
 import cupy as cp
 import dask_cudf
 import numpy as np
@@ -39,8 +36,12 @@
     filter_text_rows_by_bucket_batch,
     merge_left_to_shuffled_right,
 )
-from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if
-from nemo_curator.utils.distributed_utils import get_current_client, get_num_workers
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import (
+    get_current_client,
+    get_num_workers,
+    performance_report_if,
+)
 from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import (
     convert_str_id_to_int,
     int_ids_to_str,
@@ -1106,6 +1107,10 @@ def _run_connected_components(
         deduped_parsed_id_path,
         output_path,
     ):
+        import cugraph.dask as dcg
+        import cugraph.dask.comms.comms as Comms
+        from cugraph import MultiGraph
+
         Comms.initialize(p2p=True)
         df = dask_cudf.read_parquet(
             deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
@@ -1120,7 +1125,7 @@
         df = df[[self.left_id, self.right_id]].astype(np.int64)
         df = dask_cudf.concat([df, self_edge_df])
 
-        G = cugraph.MultiGraph(directed=False)
+        G = MultiGraph(directed=False)
         G.from_dask_cudf_edgelist(
             df, source=self.left_id, destination=self.right_id, renumber=False
         )
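Two complementary moves keep CPU-only installs importable: `nemo_curator.modules` now exports `LSH` and `MinHash` only when `fuzzy_dedup`'s GPU stack resolves, and `fuzzy_dedup.py` defers its `cugraph` imports into `_run_connected_components`, the only method that needs them, while still honoring the torch-after-cugraph ordering noted in the comment. A generic sketch of the guarded-export pattern, using a deliberately hypothetical optional module `fast_gpu_ops`:

```python
__all__ = ["always_available"]

def always_available():
    return "works on any install"

try:
    # Optional dependency: the symbol is exported only if the import succeeds.
    from fast_gpu_ops import gpu_sort  # hypothetical GPU-only module
    __all__ += ["gpu_sort"]
except ModuleNotFoundError:
    # CPU-only installs simply don't export the GPU symbols.
    pass
```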
diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py
index 71fa1cdca..33cbe1fa2 100644
--- a/nemo_curator/utils/distributed_utils.py
+++ b/nemo_curator/utils/distributed_utils.py
@@ -11,20 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import os
 
 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Union
 
-import cudf
 import dask.dataframe as dd
-import dask_cudf
 import pandas as pd
-from dask.distributed import Client, LocalCluster, get_worker
-from dask_cuda import LocalCUDACluster
+from dask.distributed import Client, LocalCluster, get_worker, performance_report
+
+from nemo_curator.utils.gpu_utils import (
+    GPU_INSTALL_STRING,
+    is_cudf_type,
+    try_dask_cudf_import_and_raise,
+)
 
 
 class DotDict:
@@ -48,6 +53,12 @@ def start_dask_gpu_local_cluster(args) -> Client:
         GPUs present on the machine.
 
     """
+    try:
+        from dask_cuda import LocalCUDACluster
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            f"Starting a GPU cluster requires GPU dependencies. {GPU_INSTALL_STRING}"
+        )
 
     # Setting conservative defaults
     # which should work across most systems
@@ -166,6 +177,8 @@ def _enable_spilling():
     i.e., computing on objects that occupy more memory than is available on the GPU.
 
     """
+    import cudf
+
     cudf.set_option("spill", True)
@@ -184,6 +197,9 @@ def read_single_partition(
         A cudf DataFrame or a pandas DataFrame.
 
     """
+    if backend == "cudf":
+        try_dask_cudf_import_and_raise("Backend=cudf requires GPU packages")
+
     if filetype == "jsonl":
         read_kwargs = {"lines": True}
         if backend == "cudf":
@@ -265,6 +281,9 @@ def read_data(
         A Dask-cuDF or a Dask-pandas DataFrame.
 
     """
+    if backend == "cudf":
+        try_dask_cudf_import_and_raise("Backend=cudf requires GPU packages")
+
     if file_type == "pickle":
         df = read_pandas_pickle(input_files[0], add_filename=add_filename)
         df = dd.from_pandas(df, npartitions=16)
@@ -369,10 +388,12 @@ def single_partition_write_with_filename(df, output_file_dir, output_type="jsonl
         warnings.warn(f"Empty partition found")
         empty_partition = False
 
-    if isinstance(df, pd.DataFrame):
-        success_ser = pd.Series([empty_partition])
-    else:
+    if is_cudf_type(df):
+        import cudf
+
         success_ser = cudf.Series([empty_partition])
+    else:
+        success_ser = pd.Series([empty_partition])
 
     if empty_partition:
         filename = df.filename.iloc[0]
@@ -425,10 +446,13 @@ def write_to_disk(df, output_file_dir, write_to_filename=False, output_type="jso
         )
 
     if write_to_filename:
-        if isinstance(df, dd.DataFrame):
-            output_meta = pd.Series([True], dtype="bool")
-        else:
+        if is_cudf_type(df):
+            import cudf
+
             output_meta = cudf.Series([True])
+        else:
+            output_meta = pd.Series([True], dtype="bool")
+
         os.makedirs(output_file_dir, exist_ok=True)
         output = df.map_partitions(
             single_partition_write_with_filename,
@@ -440,7 +464,7 @@
         output = output.compute()
     else:
         if output_type == "jsonl":
-            if isinstance(df, dask_cudf.DataFrame):
+            if is_cudf_type(df):
                 # See open issue here: https://github.com/rapidsai/cudf/issues/15211
                 # df.to_json(output_file_dir, orient="records", lines=True, engine="cudf", force_ascii=False)
                 df.to_json(
@@ -521,3 +545,10 @@ def get_current_client():
         return Client.current()
     except ValueError:
         return None
+
+
+def performance_report_if(path=None, report_name="dask-profile.html"):
+    if path is not None:
+        return performance_report(os.path.join(path, report_name))
+    else:
+        return nullcontext()
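`performance_report_if` now lives beside the other cluster utilities and depends only on `dask.distributed`, so profiling works the same on CPU and GPU clusters. A usage sketch (the cluster size, file names, and submitted task are illustrative; `performance_report` records only while a distributed client is active):

```python
from dask.distributed import Client, LocalCluster

from nemo_curator.utils.distributed_utils import performance_report_if

client = Client(LocalCluster(n_workers=1))
with performance_report_if(path=".", report_name="profile.html"):
    # Work run while the context is open lands in ./profile.html;
    # with path=None the helper degrades to a no-op nullcontext.
    client.submit(sum, [1, 2, 3]).result()
```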
diff --git a/nemo_curator/utils/gpu_utils.py b/nemo_curator/utils/gpu_utils.py
index de1c23dfe..d231f13b5 100644
--- a/nemo_curator/utils/gpu_utils.py
+++ b/nemo_curator/utils/gpu_utils.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda]`
+or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda]"` if installing from source"""
+
 
 def is_cudf_type(obj):
     """
@@ -23,3 +26,16 @@ def is_cudf_type(obj):
         str(getattr(obj, "_meta", "")),
     ]
     return any("cudf" in obj_type for obj_type in types)
+
+
+def try_dask_cudf_import_and_raise(message_prefix: str):
+    """
+    Try to import cudf and dask_cudf, raising a ModuleNotFoundError that
+    prepends ``message_prefix`` to the GPU install instructions if either
+    package is missing.
+    """
+    try:
+        import cudf
+        import dask_cudf
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(f"{message_prefix}. {GPU_INSTALL_STRING}")
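`is_cudf_type` deliberately string-matches `type(obj)` and `obj._meta` so the check itself never imports cudf, while `try_dask_cudf_import_and_raise` is the complementary fail-fast guard for code paths that genuinely need the GPU stack. An illustrative caller (`build_gpu_pipeline` is hypothetical, not part of this patch):

```python
from nemo_curator.utils.gpu_utils import try_dask_cudf_import_and_raise

def build_gpu_pipeline():
    # Fail fast with install guidance rather than letting an opaque
    # ImportError surface from deep inside a dask_cudf code path.
    try_dask_cudf_import_and_raise("Fuzzy deduplication requires GPU packages")
    import dask_cudf  # safe past the guard, which verified the import succeeds

    return dask_cudf
```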