diff --git a/README.md b/README.md
index ead0348f..a6e7e0f8 100644
--- a/README.md
+++ b/README.md
@@ -39,8 +39,9 @@ NeMo Curator provides a collection of scalable data-mining modules. Some of the
 - [Document-level deduplication](docs/user-guide/gpudeduplication.rst)
-  - Both exact and fuzzy (near-identical) deduplication are accelerated using cuDF and Dask
+  - Exact and fuzzy (near-identical) deduplication are accelerated using cuDF and Dask
   - For fuzzy deduplication, our implementation follows the method described in [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990)
+  - For semantic deduplication, our implementation follows the method described in [SemDeDup](https://arxiv.org/pdf/2303.09540) by Meta AI (FAIR) ([GitHub](https://github.com/facebookresearch/SemDeDup))
 - [Multilingual downstream-task decontamination](docs/user-guide/taskdecontamination.rst) following the approach of [OpenAI GPT3](https://arxiv.org/pdf/2005.14165.pdf) and [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990)
diff --git a/config/sem_dedup_config.yaml b/config/sem_dedup_config.yaml
new file mode 100644
index 00000000..ec847e4b
--- /dev/null
+++ b/config/sem_dedup_config.yaml
@@ -0,0 +1,32 @@
+# Configuration file for semantic dedup
+cache_dir: "semdedup_cache"
+num_files: 16
+id_col_name: "id"
+id_col_type: "int"
+input_column: "text"
+
+# Embeddings configuration
+embeddings_save_loc: "embeddings"
+embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
+embedding_batch_size: 128
+embedding_max_mem_gb: 25
+
+# Clustering configuration
+clustering_save_loc: "clustering_results"
+n_clusters: 1000
+seed: 1234
+max_iter: 100
+kmeans_with_cos_dist: false
+
+# Semdedup configuration
+which_to_keep: "hard"
+largest_cluster_size_to_process: 100000
+sim_metric: "cosine"
+
+# Extract dedup configuration
+eps_thresholds:
+  - 0.01
+  - 0.001
+
+# Which threshold to use for extracting deduped data
+eps_to_extract: 0.01
diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst
index 74c219c2..31f29069 100644
--- a/docs/user-guide/index.rst
+++ b/docs/user-guide/index.rst
@@ -46,4 +46,3 @@
    personalidentifiableinformationidentificationandremoval.rst
    distributeddataclassification.rst
    kubernetescurator.rst
-
diff --git a/examples/semdedup_example.py b/examples/semdedup_example.py
new file mode 100644
index 00000000..a1ed163b
--- /dev/null
+++ b/examples/semdedup_example.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
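For orientation, the YAML above maps field-for-field onto the `SemDedupConfig` dataclass introduced later in this PR and is loaded via `SemDedupConfig.from_yaml`, as the example script and CLI scripts below do. A minimal sketch, with an illustrative path:

```python
# Minimal sketch: load the sample YAML into a SemDedupConfig (path is illustrative).
from nemo_curator.modules.config import SemDedupConfig

config = SemDedupConfig.from_yaml("config/sem_dedup_config.yaml")
print(config.n_clusters, config.eps_thresholds, config.eps_to_extract)  # 1000 [0.01, 0.001] 0.01
```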
+ +import logging +import os +import time + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.log import create_logger +from nemo_curator.modules.config import SemDedupConfig +from nemo_curator.modules.semantic_dedup import SemDedup +from nemo_curator.utils.distributed_utils import get_client, read_data +from nemo_curator.utils.file_utils import ( + expand_outdir_and_mkdir, + get_all_files_paths_under, +) +from nemo_curator.utils.script_utils import ArgumentHelper + + +def silence_hf_warnings(): + from transformers.utils import logging + + logging.set_verbosity_error() + + +def main(args): + semdedup_config = SemDedupConfig.from_yaml(args.config_file) + client = get_client(**ArgumentHelper.parse_client_args(args)) + + silence_hf_warnings() + client.run(silence_hf_warnings) + + expand_outdir_and_mkdir(semdedup_config.cache_dir) + logger = create_logger( + rank=0, + name="logger-end-to_end-semdup", + log_file=os.path.join(semdedup_config.cache_dir, "compute_embeddings.log"), + log_level=logging.INFO, + stdout=True, + ) + st = time.time() + input_files = get_all_files_paths_under( + root=args.input_data_dir, + ) + if semdedup_config.num_files > 0: + input_files = input_files[: semdedup_config.num_files] + logger.info(f"Processing {len(input_files)} files") + ddf = read_data( + input_files=input_files, + file_type=args.input_file_type, + add_filename=False, + backend="cudf", + ) + dataset = DocumentDataset(ddf) + semdup = SemDedup(semdedup_config, logger=logger) + dedup_ids = semdup(dataset) + print(dedup_ids.df.head()) + logger.info(f"Time taken: {time.time() - st}") + client.cancel(client.futures, force=True) + client.close() + + +def attach_args(): + parser = ArgumentHelper.parse_semdedup_args(add_input_args=True) + return parser + + +def console_script(): + main(attach_args().parse_args()) + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/log.py b/nemo_curator/log.py index e69afc1f..92cf37a3 100644 --- a/nemo_curator/log.py +++ b/nemo_curator/log.py @@ -19,7 +19,7 @@ from nemo_curator.utils.file_utils import expand_outdir_and_mkdir -def create_logger(rank, log_file, name="logger", log_level=logging.INFO): +def create_logger(rank, log_file, name="logger", log_level=logging.INFO, stdout=False): # Create the logger logger = logging.getLogger(name) logger.setLevel(log_level) @@ -36,8 +36,12 @@ def create_logger(rank, log_file, name="logger", log_level=logging.INFO): file_handler.setFormatter(formatter) logger.addHandler(file_handler) - logger = logging.LoggerAdapter(logger, extra) + if stdout: + stdout_handler = logging.StreamHandler() + stdout_handler.setFormatter(formatter) + logger.addHandler(stdout_handler) + logger = logging.LoggerAdapter(logger, extra) return logger diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 8b961326..db6aca7d 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -22,7 +22,7 @@ from nemo_curator.utils.import_utils import gpu_only_import_from from .add_id import AddId -from .config import FuzzyDuplicatesConfig +from .config import FuzzyDuplicatesConfig, SemDedupConfig from .dataset_ops import blend_datasets, Shuffle from .exact_dedup import ExactDuplicates from .filter import Filter, Score, ScoreFilter @@ -36,10 +36,19 @@ FuzzyDuplicates = gpu_only_import_from( "nemo_curator.modules.fuzzy_dedup", "FuzzyDuplicates" ) - # Pytorch related imports must come after all imports that require cugraph, # because of context cleanup issues b/w 
pytorch and cugraph # See this issue: https://github.com/rapidsai/cugraph/issues/2718 +SemDedup = gpu_only_import_from("nemo_curator.modules.semantic_dedup", "SemDedup") +EmbeddingCreator = gpu_only_import_from( + "nemo_curator.modules.semantic_dedup", "EmbeddingCreator" +) +ClusteringModel = gpu_only_import_from( + "nemo_curator.modules.semantic_dedup", "ClusteringModel" +) +SemanticClusterLevelDedup = gpu_only_import_from( + "nemo_curator.modules.semantic_dedup", "SemanticClusterLevelDedup" +) from .distributed_data_classifier import DomainClassifier, QualityClassifier __all__ = [ @@ -59,4 +68,9 @@ "AddId", "blend_datasets", "Shuffle", + "SemDedup", + "SemDedupConfig", + "EmbeddingCreator", + "ClusteringModel", + "SemanticClusterLevelDedup", ] diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index 45ea527f..eec5b42e 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -13,7 +13,8 @@ # limitations under the License. import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List import yaml @@ -98,3 +99,70 @@ def __post_init__(self): raise ValueError("Jaccard Threshold must be between [0,1]") if self.buckets_per_shuffle <= 0: raise ValueError("Buckets per shuffle must be greater than 0") + + +@dataclass +class SemDedupConfig(BaseConfig): + """ + Configuration for Semantic Deduplication. + + Attributes: + cache_dir (str): Directory to store cache. + num_files (int): Number of files. Default is -1, meaning all files. + id_col_name (str): Column name for ID. + id_col_type (str): Column type for ID. + input_column (str): Input column for embeddings. + embeddings_save_loc (str): Location to save embeddings. + embedding_model_name_or_path (str): Model name or path for embeddings. + embedding_batch_size (int): Inital Batch size for processing embeddings. + embedding_max_mem_gb (int): Maximum memory in GB for embeddings. + clustering_save_loc (str): Location to save clustering results. + n_clusters (int): Number of clusters. + seed (int): Seed for clustering. + max_iter (int): Maximum iterations for clustering. + kmeans_with_cos_dist (bool): Use KMeans with cosine distance. + which_to_keep (str): Which duplicates to keep. + largest_cluster_size_to_process (int): Largest cluster size to process. + sim_metric (str): Similarity metric for deduplication. + eps_thresholds (List[float]): Epsilon thresholds to calculate if semantically similar or not. + eps_to_extract (float): Epsilon value to extract deduplicated data. 
+ """ + + cache_dir: str + num_files: int = -1 + id_col_name: str = "id" + id_col_type: str = "str" + input_column: str = "text" + + # Embeddings + embeddings_save_loc: str = "embeddings" + embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2" + embedding_batch_size: int = 128 + embedding_max_mem_gb: int = 25 + + # Clustering config + clustering_save_loc: str = "clustering_results" + n_clusters: int = 1000 + seed: int = 1234 + max_iter: int = 100 + kmeans_with_cos_dist: bool = False + + # Semdedup config + which_to_keep: str = "hard" + largest_cluster_size_to_process: int = 100000 + sim_metric: str = "cosine" + + # Extract dedup config + eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001]) + eps_to_extract: float = 0.01 + + def __post_init__(self): + if self.cache_dir is None: + raise ValueError( + "Finding sem-dedup requires a cache directory accessible via all workers to store intermediates" + ) + + if self.eps_to_extract not in self.eps_thresholds: + raise ValueError( + f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}" + ) diff --git a/nemo_curator/modules/semantic_dedup.py b/nemo_curator/modules/semantic_dedup.py new file mode 100644 index 00000000..5b95692f --- /dev/null +++ b/nemo_curator/modules/semantic_dedup.py @@ -0,0 +1,573 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
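To illustrate the `__post_init__` validation above, here is a small hedged sketch of constructing the config directly (values are arbitrary):

```python
from nemo_curator.modules.config import SemDedupConfig

# Valid: eps_to_extract appears in eps_thresholds.
cfg = SemDedupConfig(cache_dir="semdedup_cache", eps_thresholds=[0.01, 0.001], eps_to_extract=0.01)

# Invalid: eps_to_extract is missing from eps_thresholds, so __post_init__ raises ValueError.
try:
    SemDedupConfig(cache_dir="semdedup_cache", eps_thresholds=[0.01], eps_to_extract=0.05)
except ValueError as err:
    print(err)
```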
+ + +import logging +import os +import shutil +from dataclasses import dataclass +from typing import List, Optional, Union + +import cudf +import cupy as cp +import dask.bag as db +import dask.dataframe as dd +import dask_cudf +import numpy as np +import torch +import torch.nn as nn +from crossfit import op +from crossfit.backend.torch.hf.model import HFModel +from cuml.dask.cluster import KMeans +from torch.nn import functional as F +from transformers import AutoConfig, AutoModel, AutoTokenizer + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.log import create_logger +from nemo_curator.modules.config import SemDedupConfig +from nemo_curator.utils.distributed_utils import write_to_disk +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir +from nemo_curator.utils.semdedup_utils import ( + _assign_and_sort_clusters, + extract_dedup_data, + get_semantic_matches_per_cluster, +) + + +# Embedding Creation Module +@dataclass +class EmbeddingConfig: + model_name_or_path: str + max_mem_gb: int + max_seq_length: int = None + + def __post_init__(self): + self.max_seq_length = AutoTokenizer.from_pretrained( + self.model_name_or_path + ).model_max_length + # Gaurd against the HF bug + # which sets max_seq_length to max(int) for some models + if self.max_seq_length > 1e5: + self.max_seq_length = AutoConfig.from_pretrained( + self.model_name_or_path + ).max_position_embeddings + + +class EmbeddingPytorchModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.model = AutoModel.from_pretrained( + config.model_name_or_path, config=self.config, force_download=False + ) + + def feature(self, input_ids, attention_mask): + with torch.autocast(device_type=input_ids.device.type): + embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask) + return embeddings + + @torch.no_grad() + def forward(self, batch): + feature = self.feature(batch["input_ids"], batch["attention_mask"]) + return self._mean_pooling(feature, batch["attention_mask"]) + + def _mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1) + sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) + return F.normalize(sum_embeddings / sum_mask, dim=1) + + +class EmbeddingCrossFitModel(HFModel): + def __init__(self, config: EmbeddingConfig): + self.config = config + super().__init__( + self.config.model_name_or_path, max_mem_gb=self.config.max_mem_gb + ) + + def load_model(self, device="cuda"): + model = EmbeddingPytorchModel(self.config) + model = model.to(device) + model.eval() + return model + + def max_seq_length(self): + return self.config.max_seq_length + + def load_config(self): + return AutoConfig.from_pretrained(self.config.model_name_or_path) + + def load_tokenizer(self): + return AutoTokenizer.from_pretrained(self.config.model_name_or_path) + + +class EmbeddingCreator: + def __init__( + self, + embedding_model_name_or_path: str, + embedding_max_mem_gb: str, + embedding_batch_size: int, + embedding_output_dir: str, + input_column: str = "text", + write_embeddings_to_disk: bool = True, + write_to_filename: bool = False, + logger: Union[logging.Logger, str] = "./", + ): + """ + Initializes an EmbeddingCreator for generating embeddings using the specified model configurations. 
+ + Args: + embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings. + embedding_max_mem_gb (str): Maximum memory usage for the embedding process. + embedding_batch_size (int): Number of samples to process in each batch. + embedding_output_dir (str): Directory path where embeddings will be saved. + input_column (str): Column name from the data to be used for embedding generation, defaults to "text". + write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True. + We recommend setting this to False when you have a delayed pipeline. + Setting it to False can lead to more memory overhead. + write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False. + logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./". + + Attributes: + embeddings_config (EmbeddingConfig): Configuration for embeddings. + batch_size (int): Batch size for embedding generation. + logger (logging.Logger): Logger instance for the class. + embedding_output_dir (str): Output directory for embeddings. + input_column (str): Input column for data processing. + model (EmbeddingCrossFitModel): Model instance for embedding generation. + write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False. + """ + + self.embeddings_config = EmbeddingConfig( + model_name_or_path=embedding_model_name_or_path, + max_mem_gb=embedding_max_mem_gb, + ) + self.batch_size = embedding_batch_size + self.logger = self._setup_logger(logger) + self.embedding_output_dir = embedding_output_dir + self.input_column = input_column + self.model = EmbeddingCrossFitModel(self.embeddings_config) + self.write_embeddings_to_disk = write_embeddings_to_disk + self.write_to_filename = write_to_filename + + def _setup_logger(self, logger): + if isinstance(logger, str): + return create_logger( + rank=0, + name="compute-embeddings", + log_file=os.path.join(logger, "compute_embeddings.log"), + log_level=logging.INFO, + stdout=True, + ) + else: + return logger + + def create_embeddings( + self, ddf: dask_cudf.DataFrame, input_column="text" + ) -> dask_cudf.DataFrame: + pipe = op.Sequential( + op.Tokenizer( + self.model, + cols=[input_column], + tokenizer_type="sentencepiece", + max_length=self.embeddings_config.max_seq_length, + ), + op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col="embeddings", + ), + keep_cols=ddf.columns.tolist(), + ) + return pipe(ddf) + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + embedding_ddf = self.create_embeddings(dataset.df, self.input_column) + if self.write_embeddings_to_disk: + write_to_disk( + embedding_ddf, + self.embedding_output_dir, + write_to_filename=self.write_to_filename, + output_type="parquet", + ) + return DocumentDataset( + dask_cudf.read_parquet( + self.embedding_output_dir, blocksize="2GB", aggregate_files=True + ) + ) + else: + return DocumentDataset(embedding_ddf) + + +### Clustering Module +def get_embedding_ar(df: "cudf.DataFrame") -> cp.ndarray: + return df["embeddings"].list.leaves.values.reshape(len(df), -1) + + +def add_dist_to_cents(df: "cudf.DataFrame", centroids: cp.ndarray) -> "cudf.DataFrame": + embed_array = get_embedding_ar(df) + centroids_ar = centroids[df["nearest_cent"].values] + dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1)) + df["dist_to_cent"] = dist_to_cents + return df + + +class 
ClusteringModel: + def __init__( + self, + id_col: str, + max_iter: int, + n_clusters: int, + clustering_output_dir: str, + sim_metric: str = "cosine", + which_to_keep: str = "hard", + sort_clusters: bool = True, + kmeans_with_cos_dist: bool = False, + partition_size: str = "2gb", + logger: Union[logging.Logger, str] = "./", + ): + """ + Initializes the ClusteringModel with the provided settings for semantic clustering to help semantic deduplication. + + Args: + id_col (str): Column name used as the identifier in the dataset. + max_iter (int): Maximum number of iterations for the clustering algorithm. + n_clusters (int): The number of clusters to form. + clustering_output_dir (str): Directory path where clustering results will be saved. + sim_metric (str): Similarity metric to use for clustering, default is "cosine". + which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard". + sort_clusters (bool): Whether to sort clusters, default is True. + kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False. + partition_size (str): The size of data partition to run kmeans with, default is "2gb". + logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./". + + This constructor sets up the parameters required for clustering operations. + """ + self.id_col = id_col + self.max_iter = max_iter + self.n_clusters = n_clusters + self.clustering_output_dir = clustering_output_dir + self.sim_metric = sim_metric + self.keep_hard = which_to_keep == "hard" + self.kmeans_with_cos_dist = kmeans_with_cos_dist + self.partition_size = partition_size + self.sort_clusters = sort_clusters + self.logger = self._setup_logger(logger) + + if not os.path.exists(self.clustering_output_dir): + expand_outdir_and_mkdir(self.clustering_output_dir) + else: + self.logger.warning( + f"Clustering output directory {self.clustering_output_dir} already exists and will be overwritten" + ) + + def _setup_logger(self, logger): + if isinstance(logger, str): + return create_logger( + rank=0, + name="SemanticClusterLevelDedup", + log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"), + log_level=logging.INFO, + stdout=True, + ) + else: + return logger + + def __call__(self, embeddings_dataset: DocumentDataset): + embeddings_df = embeddings_dataset.df + + assert "embeddings" in embeddings_df.columns + embeddings_df = embeddings_df[[self.id_col, "embeddings"]] + + embeddings_df = embeddings_df.to_backend("pandas").persist() + embeddings_df = embeddings_df.repartition(partition_size=self.partition_size) + embeddings_df = embeddings_df.to_backend("cudf") + + cupy_darr = embeddings_df.map_partitions( + get_embedding_ar, meta=cp.ndarray([1, 1]) + ) + cupy_darr.compute_chunk_sizes() + + kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter) + self.logger.info("KMeans starting fit") + kmeans.fit(cupy_darr) + self.logger.info("KMeans fit complete") + + self.logger.info( + "Computing nearest centroids + distance to centers using kmeans.predict" + ) + nearest_cents = kmeans.predict(cupy_darr) + embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32) + del nearest_cents + meta_df = embeddings_df._meta.copy() + meta_df["dist_to_cent"] = cp.zeros(1) + embeddings_df = embeddings_df.map_partitions( + add_dist_to_cents, centroids=kmeans.cluster_centers_, meta=meta_df + ) + centroids = kmeans.cluster_centers_ + embeddings_df = embeddings_df.reset_index(drop=True) + kmeans_centroids_file = os.path.join( + 
self.clustering_output_dir, "kmeans_centroids.npy" + ) + np.save(kmeans_centroids_file, centroids) + self.logger.info("Saving centroids complete") + del kmeans, cupy_darr, centroids + + clustering_output_dir = os.path.join( + self.clustering_output_dir, "embs_by_nearest_center" + ) + if os.path.exists(clustering_output_dir): + self.logger.warning( + f"Output directory {clustering_output_dir} already exists and will be overwritten" + ) + shutil.rmtree(clustering_output_dir) + + embeddings_df.to_parquet( + clustering_output_dir, + index=False, + partition_on="nearest_cent", + ) + self.logger.info( + f"Saved embeddings by nearest center to {clustering_output_dir}" + ) + del embeddings_df + + if self.sort_clusters: + _assign_and_sort_clusters( + id_col=self.id_col, + kmeans_centroids_file=kmeans_centroids_file, + nearest_cent_dir=clustering_output_dir, + output_sorted_clusters_dir=os.path.join( + self.clustering_output_dir, "sorted" + ), + sim_metric=self.sim_metric, + keep_hard=self.keep_hard, + kmeans_with_cos_dist=self.kmeans_with_cos_dist, + cluster_ids=range(self.n_clusters), + logger=self.logger, + ) + + fps = [ + os.path.join(clustering_output_dir, file_name) + for file_name in os.listdir(clustering_output_dir) + ] + embeddings_df = dd.from_map(cudf.read_parquet, fps) + return DocumentDataset(embeddings_df) + + +class SemanticClusterLevelDedup: + def __init__( + self, + n_clusters: int, + emb_by_clust_dir: str, + sorted_clusters_dir: str, + id_col: str, + id_col_type: str, + which_to_keep: str, + output_dir: str, + logger: Union[logging.Logger, str] = "./", + ) -> None: + """ + Initialize the SemanticClusterLevelDedup class. + + Args: + n_clusters (int): Number of clusters. + emb_by_clust_dir (str): Directory containing embeddings by cluster. + sorted_clusters_dir (str): Directory containing sorted clusters. + id_col (str): Column name for IDs. + id_col_type (str): Data type of the ID column. + which_to_keep (str): Strategy for which duplicate to keep. + output_dir (str): Directory to save output files. + logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. + """ + self.n_clusters = n_clusters + self.emb_by_clust_dir = emb_by_clust_dir + self.sorted_clusters_dir = sorted_clusters_dir + self.id_col = id_col + self.id_col_type = id_col_type + self.which_to_keep = which_to_keep + self.output_dir = output_dir + self.semdedup_pruning_tables_dir = os.path.join( + output_dir, "semdedup_pruning_tables" + ) + self.computed_semantic_match_dfs = False + self.logger = self._setup_logger(logger) + + def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger: + """ + Set up the logger. + + Args: + logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. + + Returns: + logging.Logger: Configured logger. + """ + if isinstance(logger, str): + return create_logger( + rank=0, + name="SemanticClusterLevelDedup", + log_file=os.path.join(logger, "SemanticClusterLevelDedup.log"), + log_level=logging.INFO, + stdout=True, + ) + else: + return logger + + def compute_semantic_match_dfs( + self, eps_list: Optional[List[float]] = None + ) -> None: + """ + Compute semantic match dataframes for clusters. + + Args: + eps_list (Optional[List[float]]): List of epsilon values for clustering. 
+ """ + if eps_list is None: + eps_list1 = [1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5, 1.0e-6] + eps_list2 = [0.1 + x * 0.005 for x in range(34)] + eps_list = eps_list1 + eps_list2 + + if os.path.exists(self.semdedup_pruning_tables_dir): + self.logger.info( + f"Removing existing directory {self.semdedup_pruning_tables_dir}" + ) + shutil.rmtree(self.semdedup_pruning_tables_dir) + expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir) + + tasks = db.from_sequence( + list(range(self.n_clusters)), npartitions=self.n_clusters + ).map( + lambda cluster_id: get_semantic_matches_per_cluster( + cluster_id=cluster_id, + emb_by_clust_dir=self.emb_by_clust_dir, + sorted_clusters_dir=self.sorted_clusters_dir, + id_col=self.id_col, + id_col_type=self.id_col_type, + eps_list=eps_list, + output_dir=self.semdedup_pruning_tables_dir, + which_to_keep=self.which_to_keep, + ) + ) + tasks.compute() + self.computed_semantic_match_dfs = True + + def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset: + """ + Extract deduplicated data based on epsilon value. + + Args: + eps_to_extract (float): Epsilon threshold for extracting deduplicated data. + + Returns: + DocumentDataset: Dataset containing deduplicated documents. + """ + if not self.computed_semantic_match_dfs: + raise ValueError( + "Run compute_semantic_match_dfs before calling extract_dedup_data" + ) + + output_summary_file = os.path.join( + self.output_dir, f"dedup_summary_{eps_to_extract}.csv" + ) + output_parquet_path = os.path.join( + self.output_dir, f"unique_ids_{eps_to_extract}.parquet" + ) + extract_dedup_data( + eps=eps_to_extract, + n_clusters=self.n_clusters, + id_col=self.id_col, + id_col_type=self.id_col_type, + sorted_clusters_dir=self.sorted_clusters_dir, + semdedup_pruning_tables_dir=self.semdedup_pruning_tables_dir, + output_summary_file=output_summary_file, + output_parquet_path=output_parquet_path, + logger=self.logger, + ) + + fps = [ + os.path.join(output_parquet_path, file_name) + for file_name in os.listdir(output_parquet_path) + ] + return DocumentDataset.read_parquet(fps, backend="cudf") + + +class SemDedup: + def __init__( + self, + config: SemDedupConfig, + logger: Union[logging.Logger, str] = "./", + ) -> None: + """ + Initialize the SemDedup class. + + Args: + config (SemDedupConfig): Configuration for SemDedup. + logger (Union[logging.Logger, str]): Logger instance or path to the log file directory. 
+ """ + self.config = config + self.logger = logger + cache_dir = config.cache_dir + self.embedding_creator = EmbeddingCreator( + embedding_model_name_or_path=config.embedding_model_name_or_path, + max_memory=config.embedding_max_mem_gb, + batch_size=config.embedding_batch_size, + input_column=config.input_column, + embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), + logger=logger, + ) + self.clustering_model = ClusteringModel( + id_col=config.id_col_name, + max_iter=config.max_iter, + n_clusters=config.n_clusters, + clustering_output_dir=os.path.join(cache_dir, config.clustering_save_loc), + logger=logger, + ) + self.semantic_cluster_dedup = SemanticClusterLevelDedup( + n_clusters=config.n_clusters, + emb_by_clust_dir=os.path.join( + cache_dir, config.clustering_save_loc, "embs_by_nearest_center" + ), + sorted_clusters_dir=os.path.join( + cache_dir, config.clustering_save_loc, "sorted" + ), + id_col=config.id_col_name, + id_col_type=config.id_col_type, + which_to_keep=config.which_to_keep, + output_dir=os.path.join(cache_dir, config.clustering_save_loc), + logger=logger, + ) + self.eps_thresholds = config.eps_thresholds + self.eps_to_extract = config.eps_to_extract + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Execute the SemDedup process. + + Args: + dataset (DocumentDataset): Input dataset for deduplication. + + Returns: + DocumentDataset: Deduplicated dataset. + """ + embeddings_dataset = self.embedding_creator(dataset) + self.clustering_model(embeddings_dataset) + self.semantic_cluster_dedup.compute_semantic_match_dfs(self.eps_thresholds) + return self.semantic_cluster_dedup.extract_dedup_data( + eps_to_extract=self.eps_to_extract + ) diff --git a/nemo_curator/scripts/semdedup/README.md b/nemo_curator/scripts/semdedup/README.md new file mode 100644 index 00000000..68bc1680 --- /dev/null +++ b/nemo_curator/scripts/semdedup/README.md @@ -0,0 +1,40 @@ +# SemDeDup Pipeline + +This pipeline is used to cluster and deduplicate data points based on their embeddings. +Please edit "semdedup_config.yaml" to configure the pipeline and run it using the following commands. 
+
+
+## Pipeline Steps
+
+1) Modify `config/sem_dedup_config.yaml` as needed.
+
+2) Compute embeddings:
+   ```sh
+   python compute_embeddings.py --input-data-dir "$INPUT_DATA_DIR" --input-file-type "jsonl" --input-file-extension "json" --config-file "$CONFIG_FILE"
+   ```
+   **Input:** The `.jsonl` files under `--input-data-dir`
+   **Output:** Embedding parquet files under `{config.cache_dir}/{config.embeddings_save_loc}`
+
+3) Cluster the embeddings:
+   ```sh
+   python clustering.py --config-file "$CONFIG_FILE"
+   ```
+   **Input:** Output of step (2)
+
+   **Output:** Under the `{config.cache_dir}/{config.clustering_save_loc}` directory, including:
+
+   - `kmeans_centroids.npy`
+   - `embs_by_nearest_center` directory, containing `nearest_cent={x}` subdirectories where `x` ranges from 0 to `n_clusters - 1`
+   - Parquet files within `embs_by_nearest_center/nearest_cent={x}` containing the data points in each cluster
+
+4) Extract deduplicated data:
+   ```sh
+   python extract_dedup_data.py --config-file "$CONFIG_FILE"
+   ```
+   **Input:** Output of step (3)
+   **Output:** `{config.cache_dir}/{config.clustering_save_loc}/unique_ids_{eps}.parquet`, one per epsilon in `eps_thresholds`
+
+## End-to-End Script
+
+python3 examples/semdedup_example.py --input-data-dir "$INPUT_DATA_DIR" --input-file-type "jsonl" --config-file "$CONFIG_FILE"
diff --git a/nemo_curator/scripts/semdedup/__init__.py b/nemo_curator/scripts/semdedup/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/nemo_curator/scripts/semdedup/clustering.py b/nemo_curator/scripts/semdedup/clustering.py
new file mode 100644
index 00000000..82b83c54
--- /dev/null
+++ b/nemo_curator/scripts/semdedup/clustering.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
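The three scripts above are thin wrappers around the `EmbeddingCreator`, `ClusteringModel`, and `SemanticClusterLevelDedup` modules added in this PR. A condensed, hedged sketch of the same pipeline through the Python API follows; the paths, input loading, and parameter values are illustrative and simply mirror the sample config:

```python
# Sketch of the three pipeline stages via the Python API (assumes a Dask-CUDA
# client has already been started, e.g. with get_client as the scripts do).
import os

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.semantic_dedup import (
    ClusteringModel,
    EmbeddingCreator,
    SemanticClusterLevelDedup,
)

cache_dir = "semdedup_cache"  # matches cache_dir in the sample config
dataset = DocumentDataset.read_json("/path/to/jsonl_dir", backend="cudf")  # illustrative input

# Stage 1 (compute_embeddings.py): embed the input text column.
embeddings = EmbeddingCreator(
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_max_mem_gb=25,
    embedding_batch_size=128,
    embedding_output_dir=os.path.join(cache_dir, "embeddings"),
    input_column="text",
)(dataset)

# Stage 2 (clustering.py): KMeans over the embeddings, written out partitioned by nearest centroid.
ClusteringModel(
    id_col="id",
    max_iter=100,
    n_clusters=1000,
    clustering_output_dir=os.path.join(cache_dir, "clustering_results"),
)(embeddings)

# Stage 3 (extract_dedup_data.py): prune near-duplicates within each cluster.
dedup = SemanticClusterLevelDedup(
    n_clusters=1000,
    emb_by_clust_dir=os.path.join(cache_dir, "clustering_results", "embs_by_nearest_center"),
    sorted_clusters_dir=os.path.join(cache_dir, "clustering_results", "sorted"),
    id_col="id",
    id_col_type="int",
    which_to_keep="hard",
    output_dir=os.path.join(cache_dir, "clustering_results"),
)
dedup.compute_semantic_match_dfs([0.01, 0.001])
unique_ids = dedup.extract_dedup_data(eps_to_extract=0.01)  # DocumentDataset of IDs to keep
```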
+ +import logging +import os +from datetime import datetime + +os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False" +import dask_cudf + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.log import create_logger +from nemo_curator.modules.config import SemDedupConfig +from nemo_curator.modules.semantic_dedup import ClusteringModel +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + semdedup_config = SemDedupConfig.from_yaml(args.config_file) + client = get_client(**ArgumentHelper.parse_client_args(args)) + save_folder = os.path.join( + semdedup_config.cache_dir, semdedup_config.clustering_save_loc + ) + expand_outdir_and_mkdir(save_folder) + # Initialize logger + log_file = os.path.join(save_folder, "compute_centroids.log") + + logger = create_logger( + rank=0, + log_file=log_file, + log_level=logging.INFO, + name="logger-compute-centroids", + stdout=True, + ) + + client = get_client(**ArgumentHelper.parse_client_args(args)) + dt1 = datetime.now() + print("Start time:", dt1) + + embedding_fp = os.path.join( + semdedup_config.cache_dir, semdedup_config.embeddings_save_loc + ) + clustering_output_dir = os.path.join( + semdedup_config.cache_dir, semdedup_config.clustering_save_loc + ) + # Switch to https://github.com/NVIDIA/NeMo-Curator/issues/50 + # When we fix that + embedding_df = dask_cudf.read_parquet(embedding_fp, blocksize="2GB") + embedding_dataset = DocumentDataset(embedding_df) + + clustering_model = ClusteringModel( + id_col=semdedup_config.id_col_name, + max_iter=semdedup_config.max_iter, + n_clusters=semdedup_config.n_clusters, + clustering_output_dir=clustering_output_dir, + logger=logger, + ) + clustered_embeddings = clustering_model(embedding_dataset) + clustered_embeddings.df.head(10) + dt2 = datetime.now() + elapse = dt2 - dt1 + print("End time:", dt2) + print("elapse:", elapse) + + client.cancel(client.futures, force=True) + client.close() + + +def attach_args(): + parser = ArgumentHelper.parse_semdedup_args( + description=( + "Performs clustering on the computed embeddings of a collection of documents. " + "This script requires that the embeddings have been created beforehand using: " + "semdedup_extract_embeddings" + "Input arguments include: " + "--config-file for the path to the semdedup config file. " + "Important configuration parameters include: " + " cache_dir for the directory to store cache," + " clustering_save_loc for the location to save clustering results," + " n_clusters for the number of clusters," + " seed for the seed for clustering," + " max_iter for the maximum iterations for clustering," + " kmeans_with_cos_dist for using KMeans with cosine distance," + ), + add_input_args=False, + ) + return parser + + +def console_script(): + main(attach_args().parse_args()) + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/semdedup/compute_embeddings.py b/nemo_curator/scripts/semdedup/compute_embeddings.py new file mode 100644 index 00000000..b96c8d38 --- /dev/null +++ b/nemo_curator/scripts/semdedup/compute_embeddings.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.log import create_logger +from nemo_curator.modules.config import SemDedupConfig +from nemo_curator.modules.semantic_dedup import EmbeddingCreator +from nemo_curator.utils.distributed_utils import get_client, read_data +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir, get_remaining_files +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + semdedup_config = SemDedupConfig.from_yaml(args.config_file) + client = get_client(**ArgumentHelper.parse_client_args(args)) + expand_outdir_and_mkdir(semdedup_config.cache_dir) + logger = create_logger( + rank=0, + name="logger-compute-embeddings", + log_file=os.path.join(semdedup_config.cache_dir, "compute_embeddings.log"), + log_level=logging.INFO, + stdout=True, + ) + + output_data_dir = os.path.join( + semdedup_config.cache_dir, semdedup_config.embeddings_save_loc + ) + # Some time jsonl files are stored as .json + # So to handle that case we can pass the input_file_extension + if args.input_file_extension is not None: + input_file_extension = args.input_file_extension + else: + input_file_extension = args.input_file_type + print("input_file_extension", input_file_extension) + st = time.time() + input_files = get_remaining_files( + input_file_path=args.input_data_dir, + output_file_path=output_data_dir, + input_file_type=input_file_extension, + num_files=semdedup_config.num_files, + ) + logger.info(f"Processing {len(input_files)} files") + if len(input_files) == 0: + logger.info("No files to process") + return + + ddf = read_data( + input_files=input_files, file_type=args.input_file_type, add_filename=False + ) + ddf = ddf.reset_index(drop=True) + dataset = DocumentDataset(ddf) + # Can repartition here if needed + # ddf = ddf.repartition(partition_size="64MB") + embedding_creator = EmbeddingCreator( + embedding_model_name_or_path=semdedup_config.embedding_model_name_or_path, + embedding_max_mem_gb=semdedup_config.embedding_max_mem_gb, + embedding_batch_size=semdedup_config.embedding_batch_size, + embedding_output_dir=os.path.join( + semdedup_config.cache_dir, semdedup_config.embeddings_save_loc + ), + input_column=semdedup_config.input_column, + logger=logger, + write_to_filename=False, + ) + embedding_dataset = embedding_creator(dataset=dataset) + print(embedding_dataset.df.head()) + logger.info(f"Time taken: {time.time() - st}") + client.cancel(client.futures, force=True) + client.close() + + +def attach_args(): + parser = ArgumentHelper.parse_semdedup_args( + description=( + "Computes the embeddings of a collection of documents using the specified model. " + "The model is specified in the config file using embedding_model_name_or_path (e.g. 'sentence-transformers/paraphrase-MiniLM-L6-v2'). " + "The embeddings are saved in the specified cache directory under the embeddings_save_loc directory. 
" + "Input arguments include: " + "--input_data_dir for the directory containing input data files, " + "--input_file_extension for specifying the file extension of input files (e.g., .jsonl), " + "--input_file_type for the type of input files (e.g., json, csv), " + "--input_text_field for the field in the input files containing the text data to be embedded. " + "Additional configuration can be provided via the --config-file argument. " + "Important configuration parameters include: " + " cache_dir for the directory to store cache" + " num_files for the number of files to process (default is -1, meaning all files)," + " input_column for specifying the input column for embeddings," + " embeddings_save_loc for the location to save embeddings," + " embedding_model_name_or_path for the model name or path for embeddings," + " embedding_batch_size for the batch size for processing embeddings," + " embedding_max_mem_gb for the maximum memory in GB for embeddings" + ), + add_input_args=True, + ) + return parser + + +def console_script(): + main(attach_args().parse_args()) + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/semdedup/extract_dedup_data.py b/nemo_curator/scripts/semdedup/extract_dedup_data.py new file mode 100755 index 00000000..ca5016b9 --- /dev/null +++ b/nemo_curator/scripts/semdedup/extract_dedup_data.py @@ -0,0 +1,88 @@ +import logging +import os +from datetime import datetime + +from nemo_curator.log import create_logger +from nemo_curator.modules.config import SemDedupConfig +from nemo_curator.modules.semantic_dedup import SemanticClusterLevelDedup +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + semdedup_config = SemDedupConfig.from_yaml(args.config_file) + client = get_client(**ArgumentHelper.parse_client_args(args)) + + root = semdedup_config.cache_dir + save_loc = semdedup_config.clustering_save_loc + client = get_client(**ArgumentHelper.parse_client_args(args)) + + logger = create_logger( + rank=0, + log_file=os.path.join(root, save_loc, "extract_dedup_data.log"), + name="logger-extract-dedup-data", + log_level=logging.INFO, + stdout=True, + ) + + dt1 = datetime.now() + logger.info(f"Start: {dt1}") + cache_dir = semdedup_config.cache_dir + semantic_dedup = SemanticClusterLevelDedup( + n_clusters=semdedup_config.n_clusters, + emb_by_clust_dir=os.path.join( + cache_dir, semdedup_config.clustering_save_loc, "embs_by_nearest_center" + ), + sorted_clusters_dir=os.path.join( + cache_dir, semdedup_config.clustering_save_loc, "sorted" + ), + id_col=semdedup_config.id_col_name, + id_col_type=semdedup_config.id_col_type, + which_to_keep=semdedup_config.which_to_keep, + output_dir=os.path.join( + semdedup_config.cache_dir, semdedup_config.clustering_save_loc + ), + logger=logger, + ) + + semantic_dedup.compute_semantic_match_dfs() + for eps in semdedup_config.eps_thresholds: + dedup_id_dataset = semantic_dedup.extract_dedup_data(eps_to_extract=eps) + print(dedup_id_dataset.df.head(10)) + + dt2 = datetime.now() + logger.info(f"End: {dt2}") + elapse = (dt2 - dt1).total_seconds() / 60 + logger.info(f"elapse: {elapse}") + + client.cancel(client.futures, force=True) + client.close() + + +def attach_args(): + parser = ArgumentHelper.parse_semdedup_args( + description=( + "Extracts deduplicated data from the clustered embeddings of a collection of documents. 
" + "This script requires that embeddings and clustering have been performed beforehand using the specified configurations. " + "earlier using semdedup_extract_embeddings and semdedup_cluster_embeddings." + "Input arguments include: " + "--config-file for the path to the semdedup config file. " + "Important configuration parameters include:" + "- cache_dir for the directory to store cache" + "which_to_keep for specifying which duplicates to keep," + "largest_cluster_size_to_process for the largest cluster size to process," + "sim_metric for the similarity metric for deduplication," + "eps_thresholds for epsilon thresholds to calculate if semantically similar or not" + "and eps_to_extract for the epsilon value to extract deduplicated data." + ), + add_input_args=False, + ) + return parser + + +def console_script(): + main(attach_args().parse_args()) + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index c7769c4d..629cc387 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -17,12 +17,14 @@ import os os.environ["RAPIDS_NO_INITIALIZE"] = "1" +import random import warnings from contextlib import nullcontext from pathlib import Path from typing import Union import dask.dataframe as dd +import numpy as np import pandas as pd from dask.distributed import Client, LocalCluster, get_worker, performance_report @@ -216,8 +218,8 @@ def read_single_partition( " file formats.." ) - if filetype == "jsonl": - read_kwargs = {"lines": True} + if filetype in ["jsonl", "json"]: + read_kwargs = {"lines": filetype == "jsonl"} if backend == "cudf": read_f = cudf.read_json else: @@ -315,7 +317,7 @@ def read_data( if backend == "cudf": df = df.to_backend("cudf") - elif file_type in ["jsonl", "parquet"]: + elif file_type in ["json", "jsonl", "parquet"]: print(f"Reading {len(input_files)} files", flush=True) input_files = sorted(input_files) if files_per_partition > 1: @@ -583,3 +585,29 @@ def performance_report_if(path=None, report_name="dask-profile.html"): return performance_report(os.path.join(path, report_name)) else: return nullcontext() + + +def seed_all(seed: int = 42): + """ + Function to set seed for random number generators for reproducibility. + + Args: + seed: The seed value to use for random number generators. Default is 42. 
+ + Returns: + None + """ + ## Imporing torch to help with context issues + import torch + + # Set seed values for various random number generators + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Ensure deterministic behavior for CUDA algorithms + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index 3ec466b4..de5c78af 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -63,7 +63,9 @@ def get_all_files_paths_under(root, recurse_subdirectories=True, followlinks=Fal # can lead to problems when there is an error while # writing a file we can use the offset counter approach # in jaccard shuffle as a more robust way to restart jobs -def get_remaining_files(input_file_path, output_file_path, input_file_type): +def get_remaining_files( + input_file_path, output_file_path, input_file_type, num_files=-1 +): """ This function returns a list of the files that still remain to be read. @@ -71,12 +73,16 @@ def get_remaining_files(input_file_path, output_file_path, input_file_type): input_file_path: The path of the input files. output_file_path: The path of the output files. input_file_type: The type of the input files. + num_files: The max number of files to be returned. If -1, all files are returned. Returns: A list of files that still remain to be read. """ if input_file_type == "pickle": return [input_file_path] + + if not os.path.exists(output_file_path): + expand_outdir_and_mkdir(output_file_path) completed_files = [ os.path.basename(entry.path) for entry in os.scandir(output_file_path) ] @@ -86,7 +92,16 @@ def get_remaining_files(input_file_path, output_file_path, input_file_type): for entry in os.scandir(input_file_path) if os.path.basename(entry.path) not in completed_files ] + # Gaurd against non extension files if present in the input directory + input_files = [f for f in input_files if f.endswith(input_file_type)] input_files.sort() + + len_written_files = len(completed_files) + if num_files > 0: + left_to_sample = max(num_files - len_written_files, 0) + else: + left_to_sample = len(input_files) + input_files = input_files[:left_to_sample] return input_files diff --git a/nemo_curator/utils/script_utils.py b/nemo_curator/utils/script_utils.py index 32582daf..0d627257 100644 --- a/nemo_curator/utils/script_utils.py +++ b/nemo_curator/utils/script_utils.py @@ -89,6 +89,7 @@ def add_arg_log_dir(self, default: str): def add_arg_input_data_dir( self, + required=False, help: str = "Input directory consisting of .jsonl files that are accessible " "to all nodes. Use this for a distributed file system", ): @@ -96,12 +97,14 @@ def add_arg_input_data_dir( "--input-data-dir", type=str, default=None, + required=required, help=help, ) def add_arg_input_file_type( self, choices=None, + required=False, help="File type of the dataset to be read in. Supported file formats " "include 'jsonl' (default), 'pickle', or 'parquet'.", ): @@ -109,10 +112,22 @@ def add_arg_input_file_type( "--input-file-type", type=str, default="jsonl", + required=required, choices=choices, help=help, ) + def add_arg_input_file_extension( + self, + help: str = "The file extension of the input files. 
If not provided, the input file type will be used.", + ): + self.parser.add_argument( + "--input-file-extension", + type=str, + default=None, + help=help, + ) + def add_arg_input_local_data_dir(self): self.parser.add_argument( "--input-local-data-dir", @@ -496,3 +511,38 @@ def parse_gpu_dedup_args(description: str) -> argparse.ArgumentParser: ) return argumentHelper.parser + + @staticmethod + def parse_semdedup_args( + add_input_args=False, + description="Default argument parser for semantic deduplication", + ) -> argparse.ArgumentParser: + """ + Adds default set of arguments that are common to multiple stages of the semantic deduplication pipeline + of the pipeline + """ + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=description, + ) + argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_args() + if add_input_args: + argumentHelper.add_arg_input_data_dir(required=True) + argumentHelper.add_arg_input_file_extension() + argumentHelper.add_arg_input_file_type() + argumentHelper.add_arg_input_text_field() + + argumentHelper.parser.add_argument( + "--config-file", + type=str, + help="Path to the semdedup config file", + required=True, + ) + # Set low default RMM pool size for classifier + # to allow pytorch to grow its memory usage + # by default + parser.set_defaults(rmm_pool_size="512MB") + parser.set_defaults(device="gpu") + parser.set_defaults(set_torch_to_use_rmm=False) + return parser diff --git a/nemo_curator/utils/semdedup_utils.py b/nemo_curator/utils/semdedup_utils.py new file mode 100644 index 00000000..be7b6e5a --- /dev/null +++ b/nemo_curator/utils/semdedup_utils.py @@ -0,0 +1,445 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import random +import shutil +import time +from typing import List, Tuple + +import cudf +import dask.bag as db +import dask.dataframe as dd +import numpy as np +import pandas as pd +import torch +from dask.distributed import progress + +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir + + +def _assign_and_sort_clusters( + id_col: str, + kmeans_centroids_file: str, + nearest_cent_dir: str, + output_sorted_clusters_dir: str, + cluster_ids=List[int], + sim_metric: str = "cosine", + keep_hard: bool = True, + kmeans_with_cos_dist: bool = True, + logger: logging.Logger = None, +): + """ + Args: + id_col (str): The column name representing the unique identifier for each data point. + centroids_path (str): The location of the K-means centroids file. + nearest_cent_dir (str): The location of the nearest center files. + output_sorted_clusters_dir (str): The location to save the sorted clusters. + sim_metric (str): The similarity metric to use for clustering. Defaults to "cosine". + keep_hard (bool): When True, sorts cluster items in descending order by similarity to the cluster centroid. Defaults to True. 
+ kmeans_with_cos_dist (bool): Whether to use cosine distance for K-means clustering. Defaults to True. + sorted_clusters_file_loc (str): The location to save the sorted clusters file. Defaults to an empty string. + cluster_ids (list): The range of cluster IDs to sort. + logger (logging.Logger): A logger object to log messages. Defaults to None. + + Returns: + None + """ + # Step 3: Sort each class/cluster + logger.info("Ranking...") + if os.path.exists(output_sorted_clusters_dir): + logger.info( + f"Removing existing sorted cluster directory: {output_sorted_clusters_dir}" + ) + shutil.rmtree(output_sorted_clusters_dir) + + expand_outdir_and_mkdir(output_sorted_clusters_dir) + + kmeans_centroids = np.load(kmeans_centroids_file) + start_time = time.time() + + cluster_ids_bag = db.from_sequence(cluster_ids, npartitions=len(cluster_ids)) + completed_count = cluster_ids_bag.map( + lambda cluster_c: rank_within_cluster( + id_col=id_col, + nearest_cent_dir=nearest_cent_dir, + output_sorted_clusters_dir=output_sorted_clusters_dir, + centroids=kmeans_centroids, + sim_metric=sim_metric, + keep_hard=keep_hard, + kmeans_with_cos_dist=kmeans_with_cos_dist, + cluster_ids=[cluster_c], + ) + ).compute() + + missing = len(cluster_ids) - sum(completed_count) + logger.info( + f"Completed {sum(completed_count)} clusters. Missing {missing} clusters." + ) + logger.info(f"Time for ranking: {(time.time() - start_time) / 60:.2f} mins") + logger.info("DONE!") + + +def rank_within_cluster( + id_col: str, + nearest_cent_dir: str, + output_sorted_clusters_dir: str, + centroids: np.ndarray, + sim_metric: str = "cosine", + keep_hard: bool = True, + kmeans_with_cos_dist: bool = False, + cluster_ids: List[int] = range(50000), +): + """ + Sorts each cluster's items by their distance to the cluster centroid. + + Args: + id_col (str): The column name representing the unique identifier for each data point. + nearest_cent_dir (str): The location of the nearest center files. + output_sorted_clusters_dir (str): The location to save the sorted clusters. + centroids (np.ndarray): The centroids for each cluster. + sim_metric (str): The similarity metric used to compute distances. Should be one of ["cosine"]. Defaults to "cosine". + keep_hard (bool): When True, sorts cluster items in descending order by similarity to the cluster centroid. Defaults to True. + kmeans_with_cos_dist (bool): Whether to use cosine distance for K-means clustering. Defaults to False. + cluster_ids (List[int]): The list of cluster IDs to process. Defaults to range(50000). 
+ + Returns: + None + """ + assert sim_metric in [ + "cosine", + ], "sim_metric should be in ['cosine']" + + missing_files = 0 + for cluster_c in cluster_ids: + cluster_c_path = os.path.join(nearest_cent_dir, f"nearest_cent={cluster_c}") + if not os.path.exists(cluster_c_path): + missing_files += 1 + continue + + cluster_df = cudf.read_parquet( + cluster_c_path, columns=[id_col, "dist_to_cent", "embeddings"] + ) + embeds = torch.as_tensor( + cluster_df["embeddings"].list.leaves.values.reshape( + cluster_df.shape[0], -1 + ), + device="cuda", + ) + cluster_df = cluster_df.to_pandas() + + assert kmeans_with_cos_dist is False + + if sim_metric == "cosine": + cluster_c_centroid = torch.as_tensor(centroids[cluster_c], device="cuda") + sim_to_cent = torch.nn.CosineSimilarity(dim=1)(embeds, cluster_c_centroid) + sim_to_cent = sim_to_cent.cpu().numpy() + cluster_dists_to_cent = (1 - sim_to_cent).tolist() + elif sim_metric == "l2": + # Used when kmeans_with_cos_dist is True + cluster_dists_to_cent = list(cluster_df["dist_to_cent"]) + + cluster_label = np.full((len(cluster_df)), cluster_c).tolist() + example_id = list(cluster_df[id_col]) + sort_descending = keep_hard + cluster_sorted = sorted( + zip(example_id, cluster_dists_to_cent, cluster_label), + key=lambda x: x[2], + reverse=sort_descending, + ) # -- sort_descending = True for descending sort + + sorted_cluster_file_path = os.path.join( + output_sorted_clusters_dir, f"cluster_{cluster_c}.npy" + ) + np.save(sorted_cluster_file_path, cluster_sorted) + + return len(cluster_ids) - missing_files + + +def _semdedup( + cluster_reps: torch.Tensor, device: str +) -> Tuple[torch.Tensor, List[int]]: + # compute pairwise cos sim between cluster items, + # then replace to diagonal with zeros to ignore self similarity + cluster_reps.to(device) + pair_w_sim_matrix = cluster_reps @ (cluster_reps.T) + del cluster_reps + pair_w_sim_matrix.fill_diagonal_(0.0) + assert pair_w_sim_matrix.shape[0] == pair_w_sim_matrix.shape[1] + + triu_sim_mat = torch.triu(pair_w_sim_matrix, diagonal=1) + + M = torch.max(triu_sim_mat, dim=0)[0].cpu() + M1 = torch.max(triu_sim_mat, dim=0)[1].cpu().numpy().tolist() + return M, M1 + + +def get_cluster_reps( + cluster_id: int, emb_by_clust_dir: str, id_col: str, sorted_ids: np.ndarray +) -> torch.Tensor: + cluster_i_path = os.path.join(emb_by_clust_dir, f"nearest_cent={cluster_id}") + cluster_reps = cudf.read_parquet( + cluster_i_path, columns=["embeddings", id_col] + ).sort_values(by=id_col) + num = cluster_reps.shape[0] + + df_ = pd.DataFrame( + {"sorted_ids": sorted_ids, "inverse_sort": list(range(num))} + ).sort_values(by="sorted_ids") + cluster_reps["inverse_sort_id"] = df_["inverse_sort"].values + cluster_reps = cluster_reps.sort_values(by="inverse_sort_id") + + cluster_reps = torch.as_tensor( + cluster_reps["embeddings"].list.leaves.values.reshape(len(cluster_reps), -1), + device="cuda", + ) + return cluster_reps + + +def get_semantic_matches_per_cluster( + cluster_id: int, + emb_by_clust_dir: str, + sorted_clusters_dir: str, + id_col: str, + id_col_type: str, + eps_list: List[float], + output_dir: str, + which_to_keep: str, +) -> None: + + output_df_file_path = os.path.join(output_dir, f"cluster_{cluster_id}.parquet") + + sorted_file = os.path.join(sorted_clusters_dir, f"cluster_{cluster_id}.npy") + if not os.path.exists(sorted_file): + logging.info(f"{sorted_file} does not exist. 
Continue") + return + + cluster_i = np.load(sorted_file) + cluster_size = cluster_i.shape[0] + logging.info(f"{cluster_id}: cluster_size: {cluster_size}") + + if cluster_size == 1: + points_to_remove_df = pd.DataFrame() + points_to_remove_df["indices"] = [0] + for eps in eps_list: + points_to_remove_df[f"eps={eps}"] = [False] + points_to_remove_df.to_parquet(output_df_file_path) + return + + clutser_items_indices = list(range(cluster_size)) + + which_to_keep = which_to_keep.lower() + if which_to_keep == "random": + random.shuffle(clutser_items_indices) + cluster_i = cluster_i[clutser_items_indices] + elif which_to_keep == "easy": + clutser_items_indices = clutser_items_indices[::-1] + cluster_i = cluster_i[clutser_items_indices] + + text_ids = cluster_i[:, 0].astype(id_col_type) + + cluster_reps = get_cluster_reps(cluster_id, emb_by_clust_dir, id_col, text_ids) + M, M1 = _semdedup(cluster_reps, "cuda") + assert cluster_reps.shape[0] == len(text_ids) + + M1_id = [text_ids[m] for m in M1] + + points_to_remove_df = cudf.DataFrame() + points_to_remove_df["indices"] = clutser_items_indices + points_to_remove_df["id"] = text_ids + points_to_remove_df["max_id"] = M1_id + points_to_remove_df["cosine_sim_score"] = M.numpy().tolist() + + for eps in eps_list: + eps_points_to_remove = M > 1 - eps + points_to_remove_df[f"eps={eps}"] = eps_points_to_remove + + points_to_remove_df.to_parquet(output_df_file_path) + + +def get_num_records(file_path): + if not os.path.exists(file_path): + return 0 + with open(file_path, "rb") as f: + # Read the header of the npy file + version = np.lib.format.read_magic(f) + shape, _, _ = np.lib.format._read_array_header(f, version) + return shape[0] + + +def _get_empty_results_df(id_col, id_col_type): + meta_df = pd.DataFrame( + { + id_col: np.empty(0, dtype="int64"), + "dist": np.empty(0, dtype="float32"), + "cluster": np.empty(0, dtype="int32"), + } + ) + meta_df[id_col] = meta_df[id_col].astype(id_col_type) + return meta_df + + +def prune_single_cluster( + cluster_id: int, + id_col: str, + id_col_type: str, + sorted_clusters_dir: str, + semdedup_pruning_tables_dir: str, + eps: float, +) -> cudf.DataFrame: + """ + Processes data for a single cluster, applying pruning based on specified epsilon. + + Args: + cluster_id (int): The specific cluster ID to process. + id_col (str): The name of the ID column. + id_col_type (str): The data type of the ID column. + sorted_clusters_dir (str): Path to the sorted clusters directory. + semdedup_pruning_tables_dir (str): Path to the pruning tables directory. + eps (float): Epsilon value for pruning. 
+
+    Returns:
+        cudf.DataFrame: A DataFrame of the pruned cluster data.
+    """
+    sorted_fname = os.path.join(sorted_clusters_dir, f"cluster_{cluster_id}.npy")
+    if not os.path.exists(sorted_fname):
+        return _get_empty_results_df(id_col, id_col_type)
+
+    cluster_data = np.load(sorted_fname)
+    df_cluster = cudf.DataFrame(
+        {
+            id_col: cluster_data[:, 0],
+            "dist": cluster_data[:, 1],
+            "cluster": cluster_data[:, 2],
+        }
+    )
+
+    df_cluster[id_col] = df_cluster[id_col].astype(id_col_type)
+    df_cluster["dist"] = df_cluster["dist"].astype("float32")
+    df_cluster["cluster"] = df_cluster["cluster"].astype("int32")
+
+    cluster_df_fname = os.path.join(
+        semdedup_pruning_tables_dir, f"cluster_{cluster_id}.parquet"
+    )
+    pruning_table = cudf.read_parquet(cluster_df_fname)
+    if pruning_table.shape[0] == 1:
+        return df_cluster
+
+    # TODO: Fix this without going to host
+    items_to_keep = (
+        pruning_table[pruning_table[f"eps={eps}"] == False]["id"].to_arrow().to_pylist()
+    )
+    pruned_cluster = df_cluster[df_cluster[id_col].isin(items_to_keep)]
+    pruned_cluster[id_col] = pruned_cluster[id_col].astype(id_col_type)
+    return pruned_cluster
+
+
+def extract_pruned_data(
+    id_col: str,
+    id_col_type: str,
+    sorted_clusters_dir: str,
+    semdedup_pruning_tables_dir: str,
+    eps: float,
+    n_clusters: int,
+    output_parquet_path: str,
+) -> Tuple[int, int, int]:
+    """
+    Extracts pruned data from the sorted clusters and saves it as a Parquet dataset.
+
+    Args:
+        id_col (str): The name of the ID column.
+        id_col_type (str): The data type of the ID column.
+        sorted_clusters_dir (str): Path to the sorted clusters directory.
+        semdedup_pruning_tables_dir (str): Path to the pruning tables directory.
+        eps (float): Epsilon value for pruning.
+        n_clusters (int): Number of clusters.
+        output_parquet_path (str): Path to save the output Parquet dataset.
+
+    Returns:
+        Tuple[int, int, int]: Number of kept records, removed records, and total records.
+    """
+
+    results_df = dd.from_map(
+        prune_single_cluster,
+        range(n_clusters),
+        id_col=id_col,
+        id_col_type=id_col_type,
+        sorted_clusters_dir=sorted_clusters_dir,
+        semdedup_pruning_tables_dir=semdedup_pruning_tables_dir,
+        eps=eps,
+    )
+    results_df[id_col] = results_df[id_col].astype(id_col_type)
+    results_df = results_df.persist()
+    progress(results_df)
+
+    results_df.to_parquet(output_parquet_path)
+    total_kept = len(results_df)
+
+    np_files = [
+        os.path.join(sorted_clusters_dir, f"cluster_{i}.npy") for i in range(n_clusters)
+    ]
+    total_records = sum(get_num_records(file_path) for file_path in np_files)
+    # Aggregate results
+    total_removed = total_records - total_kept
+    return total_kept, total_removed, total_records
+
+
+def extract_dedup_data(
+    eps,
+    n_clusters,
+    id_col,
+    id_col_type,
+    sorted_clusters_dir,
+    semdedup_pruning_tables_dir,
+    output_summary_file,
+    output_parquet_path,
+    logger: logging.Logger,
+) -> dd.DataFrame:
+    """
+    Extracts deduplicated data based on the provided parameters and logs the process.
+
+    Args:
+        eps (float): Epsilon value used to select the pruning threshold.
+        n_clusters (int): Number of clusters.
+        id_col (str): The name of the ID column.
+        id_col_type (str): The data type of the ID column.
+        sorted_clusters_dir (str): Path to the sorted clusters directory.
+        semdedup_pruning_tables_dir (str): Path to the pruning tables directory.
+        output_summary_file (str): Path to save the summary CSV file.
+        output_parquet_path (str): Path to save the IDs to keep as a Parquet dataset.
+        logger (logging.Logger): A logger object to log messages.
+
+    Returns:
+        dd.DataFrame: A Dask DataFrame with the IDs to keep.
+    """
+
+    kept, removed, total = extract_pruned_data(
+        id_col=id_col,
+        id_col_type=id_col_type,
+        sorted_clusters_dir=sorted_clusters_dir,
+        semdedup_pruning_tables_dir=semdedup_pruning_tables_dir,
+        eps=eps,
+        n_clusters=n_clusters,
+        output_parquet_path=output_parquet_path,
+    )
+
+    logger.info(
+        f"DONE saving {kept} out of {total}. Removed: {removed}. Epsilon: {eps:.4f}"
+    )
+    result_dict = {
+        "eps": [eps],
+        "kept": [kept],
+        "removed": [removed],
+        "total": [total],
+    }
+    df = pd.DataFrame(result_dict)
+    df.to_csv(output_summary_file, index=False)
+
+    fps = [
+        os.path.join(output_parquet_path, file_name)
+        for file_name in os.listdir(output_parquet_path)
+    ]
+    ids_to_keep_df = dd.from_map(cudf.read_parquet, fps)
+    return ids_to_keep_df
diff --git a/setup.py b/setup.py
index 933f4c2d..a6884061 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,8 @@
         "presidio-anonymizer==2.2.351",
         "usaddress==0.5.10",
         "nemo_toolkit[nlp]>=1.23.0",
-        "crossfit @ git+https://github.com/rapidsai/crossfit.git@1ee3de4",
+        "Cython",
+        "crossfit @ git+https://github.com/rapidsai/crossfit.git@0.0.2",
         # justext installation breaks without lxml[html_clean]
         # due to this: https://github.com/miso-belica/jusText/issues/47
         "lxml[html_clean]",
@@ -107,6 +108,9 @@
             "quality_classifier_inference=nemo_curator.scripts.quality_classifier_inference:console_script",
             "verify_classification_results=nemo_curator.scripts.verify_classification_results:console_script",
             "blend_datasets=nemo_curator.scripts.blend_datasets:console_script",
+            "semdedup_extract_embeddings=nemo_curator.scripts.semdedup.compute_embeddings:console_script",
+            "semdedup_clustering=nemo_curator.scripts.semdedup.clustering:console_script",
+            "semdedup_extract_dedup_ids=nemo_curator.scripts.semdedup.extract_dedup_data:console_script",
         ],
     },
 )
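
Note on the pruning criterion added in this change: within each cluster, _semdedup computes pairwise cosine similarities, keeps only the upper triangle so every item is compared against the items ranked before it, and get_semantic_matches_per_cluster then flags an item as a duplicate at a given threshold when its maximum similarity exceeds 1 - eps. The snippet below is a minimal CPU-only sketch of that criterion for illustration; the mark_duplicates helper and the toy embeddings are hypothetical and not part of this diff.

# Illustrative sketch of the per-cluster SemDeDup criterion (not part of the diff).
import torch


def mark_duplicates(embeddings: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize rows so dot products are cosine similarities.
    reps = torch.nn.functional.normalize(embeddings, dim=1)
    sim = reps @ reps.T
    sim.fill_diagonal_(0.0)
    # Only compare each item against items that come earlier in the kept order,
    # mirroring torch.triu(..., diagonal=1) followed by a column-wise max.
    triu = torch.triu(sim, diagonal=1)
    max_sim = torch.max(triu, dim=0)[0]
    # An item is a near-duplicate at this eps if it is within eps of some earlier item.
    return max_sim > 1 - eps


if __name__ == "__main__":
    toy = torch.tensor(
        [
            [1.0, 0.0, 0.0],     # first item is never removed
            [0.999, 0.01, 0.0],  # nearly identical to the first -> removed at eps=0.01
            [0.0, 1.0, 0.0],     # clearly different -> kept
        ]
    )
    print(mark_duplicates(toy, eps=0.01))  # tensor([False,  True, False])

Because rank_within_cluster stores each cluster hardest-first (or shuffled / easy-first, per which_to_keep), that ordering determines which member of a near-duplicate pair survives.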