diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py
index 0d7f1a6aa..e268e6676 100644
--- a/nemo_curator/modules/config.py
+++ b/nemo_curator/modules/config.py
@@ -111,11 +111,10 @@ class SemDedupConfig(BaseConfig):
         id_col_name (str): Column name for ID.
         id_col_type (str): Column type for ID.
         input_column (str): Input column for embeddings.
-        input_file_type (str): File type for input embeddings.
         embeddings_save_loc (str): Location to save embeddings.
-        model_name_or_path (str): Model name or path for embeddings.
-        batch_size (int): Inital Batch size for processing embeddings.
-        max_mem_gb (int): Maximum memory in GB for embeddings.
+        embedding_model_name_or_path (str): Model name or path for embeddings.
+        embedding_batch_size (int): Initial batch size for processing embeddings.
+        embedding_max_mem_gb (int): Maximum memory in GB for embeddings.
         clustering_save_loc (str): Location to save clustering results.
         n_clusters (int): Number of clusters.
         seed (int): Seed for clustering.
@@ -124,7 +123,8 @@ class SemDedupConfig(BaseConfig):
         which_to_keep (str): Which duplicates to keep.
         largest_cluster_size_to_process (int): Largest cluster size to process.
         sim_metric (str): Similarity metric for deduplication.
-        eps (str): Epsilon values to calculate if semantically similar or not
+        eps_thresholds (str): Space-separated epsilon thresholds used to determine whether documents are semantically similar.
+        eps_to_extract (float): Epsilon value at which to extract deduplicated data.
     """

     cache_dir: str
@@ -134,7 +134,6 @@ class SemDedupConfig(BaseConfig):
     input_column: str = "text"

     # Embeddings
-    input_file_type: str = "json"
     embeddings_save_loc: str = "embeddings"
     embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
     embedding_batch_size: int = 128
@@ -154,7 +153,7 @@ class SemDedupConfig(BaseConfig):

     # Extract dedup config
     eps_thresholds: str = "0.01 0.001"
-    eps_to_extract: str = "0.01"
+    eps_to_extract: float = 0.01

     def __post_init__(self):
         self.eps_thresholds = [float(x) for x in self.eps_thresholds.split()]
@@ -162,3 +161,8 @@ def __post_init__(self):
             raise ValueError(
                 "Finding sem-dedup requires a cache directory accessible via all workers to store intermediates"
             )
+
+        if self.eps_to_extract not in self.eps_thresholds:
+            raise ValueError(
+                f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}"
+            )
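A minimal sketch of how the reworked `SemDedupConfig` behaves after `__post_init__`, based only on the fields and validation shown in this diff (the cache path is an illustrative value):

```python
# Sketch of the reworked SemDedupConfig; the cache path is a made-up example.
from nemo_curator.modules.config import SemDedupConfig

# eps_thresholds is still supplied as a space-separated string and is parsed
# into a list of floats in __post_init__; eps_to_extract is now a float.
config = SemDedupConfig(
    cache_dir="/tmp/semdedup_cache",
    eps_thresholds="0.01 0.001",
    eps_to_extract=0.01,
)
print(config.eps_thresholds)  # [0.01, 0.001]

# The new check rejects an eps_to_extract that is not among the parsed thresholds:
try:
    SemDedupConfig(cache_dir="/tmp/semdedup_cache", eps_to_extract=0.1)
except ValueError as err:
    print(err)  # Epsilon to extract 0.1 must be in eps_thresholds [0.01, 0.001]
```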
diff --git a/nemo_curator/modules/semantic_dedup.py b/nemo_curator/modules/semantic_dedup.py
index 9bb171fcf..0acbf0c94 100644
--- a/nemo_curator/modules/semantic_dedup.py
+++ b/nemo_curator/modules/semantic_dedup.py
@@ -118,9 +118,9 @@ def load_tokenizer(self):
 class EmbeddingCreator:
     def __init__(
         self,
-        model_name_or_path: str,
-        max_memory: str,
-        batch_size: int,
+        embedding_model_name_or_path: str,
+        embedding_max_mem_gb: str,
+        embedding_batch_size: int,
         embedding_output_dir: str,
         input_column: str = "text",
         write_embeddings_to_disk: bool = True,
@@ -131,9 +131,9 @@ def __init__(
         Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.

         Args:
-            model_name_or_path (str): The path or identifier for the model used to generate embeddings.
-            max_memory (str): Maximum memory usage for the embedding process.
-            batch_size (int): Number of samples to process in each batch.
+            embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
+            embedding_max_mem_gb (str): Maximum memory usage for the embedding process.
+            embedding_batch_size (int): Number of samples to process in each batch.
             embedding_output_dir (str): Directory path where embeddings will be saved.
             input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
             write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
@@ -153,9 +153,10 @@ def __init__(
         """

         self.embeddings_config = EmbeddingConfig(
-            model_name_or_path=model_name_or_path, max_mem_gb=max_memory
+            model_name_or_path=embedding_model_name_or_path,
+            max_mem_gb=embedding_max_mem_gb,
         )
-        self.batch_size = batch_size
+        self.batch_size = embedding_batch_size
         self.logger = self._setup_logger(logger)
         self.embedding_output_dir = embedding_output_dir
         self.input_column = input_column
diff --git a/nemo_curator/scripts/semdedup/clustering.py b/nemo_curator/scripts/semdedup/clustering.py
index 258572ddf..51287270a 100644
--- a/nemo_curator/scripts/semdedup/clustering.py
+++ b/nemo_curator/scripts/semdedup/clustering.py
@@ -80,7 +80,23 @@ def attach_args():
-    parser = ArgumentHelper.parse_semdedup_args(add_input_args=False)
+    parser = ArgumentHelper.parse_semdedup_args(
+        description=(
+            "Performs clustering on the computed embeddings of a collection of documents. "
+            "This script requires that the embeddings have been created beforehand using "
+            "semdedup_extract_embeddings. "
+            "Input arguments include: "
+            "--config-file for the path to the semdedup config file. "
+            "Important configuration parameters include: "
+            " cache_dir for the directory to store cache, "
+            " clustering_save_loc for the location to save clustering results, "
+            " n_clusters for the number of clusters, "
+            " seed for the seed for clustering, "
+            " max_iter for the maximum iterations for clustering, "
+            " and Kmeans_with_cos_dist for whether to use KMeans with cosine distance."
+        ),
+        add_input_args=False,
+    )
     return parser
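To make the rename concrete, here is a caller sketch against the new `EmbeddingCreator` signature; the model name matches the config default, while the output directory and memory budget are illustrative values only:

```python
# Illustrative use of the renamed EmbeddingCreator keyword arguments.
from nemo_curator.modules.semantic_dedup import EmbeddingCreator

embedding_creator = EmbeddingCreator(
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_max_mem_gb=25,  # forwarded to EmbeddingConfig.max_mem_gb
    embedding_batch_size=128,
    embedding_output_dir="/tmp/semdedup_cache/embeddings",  # example path
    input_column="text",
)
```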
" + "Input arguments include: " + "--input_data_dir for the directory containing input data files, " + "--input_file_extension for specifying the file extension of input files (e.g., .jsonl), " + "--input_file_type for the type of input files (e.g., json, csv), " + "--input_text_field for the field in the input files containing the text data to be embedded. " + "Additional configuration can be provided via the --config-file argument. " + "Important configuration parameters include: " + " cache_dir for the directory to store cache" + " num_files for the number of files to process (default is -1, meaning all files)," + " input_column for specifying the input column for embeddings," + " embeddings_save_loc for the location to save embeddings," + " embedding_model_name_or_path for the model name or path for embeddings," + " embedding_batch_size for the batch size for processing embeddings," + " embedding_max_mem_gb for the maximum memory in GB for embeddings" + ), + add_input_args=True, + ) return parser diff --git a/nemo_curator/scripts/semdedup/extract_dedup_data.py b/nemo_curator/scripts/semdedup/extract_dedup_data.py index 36ed142fc..ca5016b98 100755 --- a/nemo_curator/scripts/semdedup/extract_dedup_data.py +++ b/nemo_curator/scripts/semdedup/extract_dedup_data.py @@ -57,11 +57,26 @@ def main(args): client.cancel(client.futures, force=True) client.close() - return def attach_args(): - parser = ArgumentHelper.parse_semdedup_args(add_input_args=False) + parser = ArgumentHelper.parse_semdedup_args( + description=( + "Extracts deduplicated data from the clustered embeddings of a collection of documents. " + "This script requires that embeddings and clustering have been performed beforehand using the specified configurations. " + "earlier using semdedup_extract_embeddings and semdedup_cluster_embeddings." + "Input arguments include: " + "--config-file for the path to the semdedup config file. " + "Important configuration parameters include:" + "- cache_dir for the directory to store cache" + "which_to_keep for specifying which duplicates to keep," + "largest_cluster_size_to_process for the largest cluster size to process," + "sim_metric for the similarity metric for deduplication," + "eps_thresholds for epsilon thresholds to calculate if semantically similar or not" + "and eps_to_extract for the epsilon value to extract deduplicated data." + ), + add_input_args=False, + ) return parser diff --git a/nemo_curator/utils/semdedup_utils.py b/nemo_curator/utils/semdedup_utils.py index b7fcaa075..11f1deb31 100644 --- a/nemo_curator/utils/semdedup_utils.py +++ b/nemo_curator/utils/semdedup_utils.py @@ -52,7 +52,7 @@ def _assign_and_sort_clusters( keep_hard (bool): When True, sorts cluster items in descending order by similarity to the cluster centroid. Defaults to True. kmeans_with_cos_dist (bool): Whether to use cosine distance for K-means clustering. Defaults to True. sorted_clusters_file_loc (str): The location to save the sorted clusters file. Defaults to an empty string. - cluster_ids (list): The range of cluster IDs to sort. Defaults to range(5000). + cluster_ids (list): The range of cluster IDs to sort. logger (logging.Logger): A logger object to log messages. Defaults to None. Returns: @@ -268,7 +268,6 @@ def get_semantic_matches_per_cluster( points_to_remove_df[f"eps={eps}"] = eps_points_to_remove points_to_remove_df.to_parquet(output_df_file_path) - return None def get_num_records(file_path):