diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 835f354bd5..0d0f33d9a6 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -30,6 +30,7 @@ from distilabel.steps.deita import DeitaFiltering from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour +from distilabel.steps.filtering.embedding import EmbeddingDedup from distilabel.steps.filtering.minhash import MinHashDedup from distilabel.steps.formatting.conversation import ConversationTemplate from distilabel.steps.formatting.dpo import ( @@ -79,6 +80,7 @@ "LoadDataFromDisk", "LoadDataFromFileSystem", "LoadDataFromHub", + "EmbeddingDedup", "MinHashDedup", "make_generator_step", "PushToHub", diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py index adf3fe6c04..98b646d9ee 100644 --- a/src/distilabel/steps/embeddings/nearest_neighbour.py +++ b/src/distilabel/steps/embeddings/nearest_neighbour.py @@ -46,6 +46,8 @@ class FaissNearestNeighbour(GlobalStep): search_batch_size: the number of rows to include in a search batch. The value can be adjusted to maximize the resources usage or to avoid OOM issues. Defaults to `50`. + train_size: If the index needs a training step, specifies how many vectors will be + used to train the index. Runtime parameters: - `device`: the CUDA device ID or a list of IDs to be used. If negative integer, @@ -60,6 +62,8 @@ class FaissNearestNeighbour(GlobalStep): - `search_batch_size`: the number of rows to include in a search batch. The value can be adjusted to maximize the resources usage or to avoid OOM issues. Defaults to `50`. + - `train_size`: If the index needs a training step, specifies how many vectors will + be used to train the index. Input columns: - embedding (`List[Union[float, int]]`): a sentence embedding. @@ -148,6 +152,10 @@ class FaissNearestNeighbour(GlobalStep): description="The number of rows to include in a search batch. The value can be adjusted" " to maximize the resources usage or to avoid OOM issues.", ) + train_size: Optional[RuntimeParameter[int]] = Field( + default=None, + description="If the index needs a training step, specifies how many vectors will be used to train the index.", + ) def load(self) -> None: super().load() @@ -176,11 +184,14 @@ def _build_index(self, inputs: List[Dict[str, Any]]) -> Dataset: The build `datasets.Dataset` with its `faiss` index. """ dataset = Dataset.from_list(inputs) + if self.train_size is not None and self.string_factory: + self._logger.info("🏋️‍♀️ Starting Faiss index training...") dataset.add_faiss_index( column="embedding", device=self.device, # type: ignore string_factory=self.string_factory, metric_type=self.metric_type, + train_size=self.train_size, ) return dataset diff --git a/src/distilabel/steps/filtering/embedding.py b/src/distilabel/steps/filtering/embedding.py new file mode 100644 index 0000000000..cb1e710374 --- /dev/null +++ b/src/distilabel/steps/filtering/embedding.py @@ -0,0 +1,192 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List, Optional + +import numpy as np +from pydantic import Field +from rich.progress import track +from typing_extensions import override + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.base import GlobalStep, StepInput + +if TYPE_CHECKING: + from distilabel.steps.typing import StepOutput + + +class EmbeddingDedup(GlobalStep): + """Deduplicates text using embeddings. + + `EmbeddingDedup` is a Step that detects near-duplicates in datasets, using + embeddings to compare the similarity between the texts. The typical workflow with this step + would include having a dataset with embeddings precomputed, and then (possibly using the + `FaissNearestNeighbour`) using the `nn_indices` and `nn_scores`, determine the texts that + are duplicate. + + Attributes: + threshold: the threshold to consider 2 examples as duplicates. + It's dependent on the type of index that was used to generate the embeddings. + For example, if the embeddings were generated using cosine similarity, a threshold + of `0.9` would make all the texts with a cosine similarity above the value + duplicates. Higher values detect less duplicates in such an index, but that should + be taken into account when building it. Defaults to `0.9`. + + Runtime Parameters: + - `threshold`: the threshold to consider 2 examples as duplicates. + + Input columns: + - nn_indices (`List[int]`): a list containing the indices of the `k` nearest neighbours + in the inputs for the row. + - nn_scores (`List[float]`): a list containing the score or distance to each `k` + nearest neighbour in the inputs. + + Output columns: + - keep_row_after_embedding_filtering (`bool`): boolean indicating if the piece `text` is + not a duplicate i.e. this text should be kept. + + Categories: + - filtering + + Examples: + + Deduplicate a list of texts using embedding information: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps import EmbeddingDedup + from distilabel.steps import LoadDataFromDicts + + with Pipeline() as pipeline: + data = LoadDataFromDicts( + data=[ + { + "persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.", + "embedding": [ + 0.018477669046149742, + -0.03748236608841726, + 0.001919870620352492, + 0.024918478063770535, + 0.02348063521315178, + 0.0038251285566308375, + -0.01723884983037716, + 0.02881971942372201, + ], + "nn_indices": [0, 1], + "nn_scores": [ + 0.9164746999740601, + 0.782106876373291, + ], + }, + { + "persona": "A music teacher or instructor focused on theoretical and practical piano lessons.", + "embedding": [ + -0.0023464179614082125, + -0.07325472251663565, + -0.06058678419516501, + -0.02100326928586996, + -0.013462744792362657, + 0.027368447064244242, + -0.003916070100455717, + 0.01243614518480423, + ], + "nn_indices": [0, 2], + "nn_scores": [ + 0.7552462220191956, + 0.7261884808540344, + ], + }, + { + "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.", + "embedding": [ + -0.01630817942328242, + -0.023760151552345232, + -0.014249650090627883, + -0.005713686451446624, + -0.016033059279131567, + 0.0071440908501058786, + -0.05691099643425161, + 0.01597412704817784, + ], + "nn_indices": [1, 2], + "nn_scores": [ + 0.8107735514640808, + 0.7172299027442932, + ], + }, + ], + batch_size=batch_size, + ) + # In general you should do something like this before the deduplication step, to obtain the + # `nn_indices` and `nn_scores`. In this case the embeddings are already normalized, so there's + # no need for it. + # nn = FaissNearestNeighbour( + # k=30, + # metric_type=faiss.METRIC_INNER_PRODUCT, + # search_batch_size=50, + # train_size=len(dataset), # The number of embeddings to use for training + # string_factory="IVF300_HNSW32,Flat" # To use an index (optional, maybe required for big datasets) + # ) + # Read more about the `string_factory` here: + # https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index + + embedding_dedup = EmbeddingDedup( + threshold=0.8, + input_batch_size=batch_size, + ) + + data >> embedding_dedup + + if __name__ == "__main__": + distiset = pipeline.run(use_cache=False) + ds = distiset["default"]["train"] + # Filter out the duplicates + ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"]) + ``` + """ + + threshold: Optional[RuntimeParameter[float]] = Field( + default=0.9, + description="The threshold to consider 2 examples as duplicates. It's dependent " + "on the type of index that was used to generate the embeddings. For example, if " + "the embeddings were generated using cosine similarity, a threshold of `0.9` " + "would make all the texts with a cosine similarity above the value duplicates. " + "Higher values detect less duplicates in such an index, but that should be " + "taken into account when building it.", + ) + + @property + def inputs(self) -> List[str]: + return ["nn_scores", "nn_indices"] + + @property + def outputs(self) -> List[str]: + return ["keep_row_after_embedding_filtering"] + + @override + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + rows_to_remove = set() + + for input in track(inputs, description="Running Embedding deduplication..."): + input["keep_row_after_embedding_filtering"] = True + indices_scores = np.array(input["nn_scores"]) > self.threshold + indices = np.array(input["nn_indices"])[indices_scores] + if len(indices) > 0: # If there are any rows found over the threshold + rows_to_remove.update(list(indices)) + + # Remove duplicates and get the list of rows to remove + for idx in rows_to_remove: + inputs[idx]["keep_row_after_embedding_filtering"] = False + + yield inputs diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py index 5f82f64731..ea72c9a0c6 100644 --- a/src/distilabel/steps/generators/huggingface.py +++ b/src/distilabel/steps/generators/huggingface.py @@ -243,20 +243,18 @@ def _dataset_info(self) -> Dict[str, DatasetInfo]: Returns: The dataset information. """ - repo_id = self.repo_id - config = self.config try: - return get_dataset_infos(repo_id) + return get_dataset_infos(self.repo_id) except Exception as e: # The previous could fail in case of a internet connection issues. # Assuming the dataset is already loaded and we can get the info from the loaded dataset, otherwise it will fail anyway. self._logger.warning( f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}" ) - ds = load_dataset(repo_id, config=self.config, split=self.split) - if config: - return ds[config].info + ds = load_dataset(self.repo_id, config=self.config, split=self.split) + if self.config: + return ds[self.config].info return ds.info diff --git a/src/distilabel/steps/generators/utils.py b/src/distilabel/steps/generators/utils.py index de119ab2ef..27455119bd 100644 --- a/src/distilabel/steps/generators/utils.py +++ b/src/distilabel/steps/generators/utils.py @@ -32,6 +32,7 @@ def make_generator_step( input_mappings: Optional[Dict[str, str]] = None, output_mappings: Optional[Dict[str, str]] = None, resources: StepResources = StepResources(), + repo_id: str = "placeholder", ) -> "GeneratorStep": """Helper method to create a `GeneratorStep` from a dataset, to simplify @@ -42,6 +43,10 @@ def make_generator_step( input_mappings: Applies the same as any other step. Defaults to `None`. output_mappings: Applies the same as any other step. Defaults to `None`. resources: Applies the same as any other step. Defaults to `StepResources()`. + repo_id: The repository ID to use in the `LoadDataFromHub` step. + This shouldn't be necessary, but in case of error, the dataset will try to be loaded + using `load_dataset` internally. If that case happens, the `repo_id` will be used. + Defaults to `"placeholder"`. Raises: ValueError: If the format is different from the ones supported. @@ -74,12 +79,13 @@ def make_generator_step( loader = LoadDataFromHub( pipeline=pipeline, - repo_id="placeholder_name", + repo_id=repo_id, batch_size=batch_size, input_mappings=input_mappings or {}, output_mappings=output_mappings or {}, resources=resources, ) + super(loader.__class__, loader).load() # Ensure the logger is loaded loader._dataset = dataset loader.num_examples = len(dataset) loader._dataset_info = {"default": dataset.info} diff --git a/tests/integration/test_embedding_dedup.py b/tests/integration/test_embedding_dedup.py new file mode 100644 index 0000000000..7806cf6761 --- /dev/null +++ b/tests/integration/test_embedding_dedup.py @@ -0,0 +1,130 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +import faiss +import numpy as np + +from distilabel.pipeline import Pipeline +from distilabel.steps import FaissNearestNeighbour, LoadDataFromDicts, StepInput, step +from distilabel.steps.filtering.embedding import EmbeddingDedup + +if TYPE_CHECKING: + from distilabel.steps.typing import StepOutput + + +SAMPLE_DATA = [ + { + "text": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.", + "embedding": [ + 0.018477669046149742, + -0.03748236608841726, + 0.001919870620352492, + 0.024918478063770535, + 0.02348063521315178, + 0.0038251285566308375, + -0.01723884983037716, + 0.02881971942372201, + ], + }, + { + "text": "A music teacher or instructor focused on theoretical and practical piano lessons.", + "embedding": [ + -0.0023464179614082125, + -0.07325472251663565, + -0.06058678419516501, + -0.02100326928586996, + -0.013462744792362657, + 0.027368447064244242, + -0.003916070100455717, + 0.01243614518480423, + ], + }, + { + "text": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.", + "embedding": [ + -0.01630817942328242, + -0.023760151552345232, + -0.014249650090627883, + -0.005713686451446624, + -0.016033059279131567, + 0.0071440908501058786, + -0.05691099643425161, + 0.01597412704817784, + ], + }, + { + "text": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.", + "embedding": [ + -0.01630817942328242, + -0.023760151552345232, + -0.014249650090627883, + -0.005713686451446624, + -0.016033059279131567, + 0.0071440908501058786, + -0.05691099643425161, + 0.01597412704817784, + ], + }, +] + + +@step(inputs=["embedding"], outputs=["embedding"]) +def NormalizeEmbeddings(inputs: StepInput) -> "StepOutput": + # Normalize a vector to have length 1 + for input in inputs: + norm = np.linalg.norm(input["embedding"]) + if norm == 0: + print("Cannot normalize a zero vector") + continue + input["embedding"] = input["embedding"] / norm + yield inputs + + +def test_embedding_deduplication() -> None: + with Pipeline() as pipeline: + loader = LoadDataFromDicts( + data=SAMPLE_DATA * 20, + batch_size=50, + ) + batch_size = 50 + + # NOTE: Guide to choose an index: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index + nn = FaissNearestNeighbour( + k=3, + metric_type=faiss.METRIC_INNER_PRODUCT, + search_batch_size=50, + # string_factory="IVF300_HNSW32,Flat", + # train_size=len(dataset), + input_batch_size=batch_size, + ) + + embedding_dedup = EmbeddingDedup( + threshold=0.99, + input_batch_size=batch_size, + ) + normalize = NormalizeEmbeddings() + loader >> normalize >> nn >> embedding_dedup + + distiset = pipeline.run(use_cache=False) + + ds = distiset["default"]["train"] + ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"]) + print(len(ds_dedup)) + assert len(ds_dedup) == 71 + + +if __name__ == "__main__": + test_embedding_deduplication() diff --git a/tests/unit/steps/filtering/test_embeddings.py b/tests/unit/steps/filtering/test_embeddings.py new file mode 100644 index 0000000000..354777bd94 --- /dev/null +++ b/tests/unit/steps/filtering/test_embeddings.py @@ -0,0 +1,104 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from distilabel.steps.filtering.embedding import EmbeddingDedup + +SAMPLE_DATA = [ + { + "persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.", + "embedding": [ + 0.018477669046149742, + -0.03748236608841726, + 0.001919870620352492, + 0.024918478063770535, + 0.02348063521315178, + 0.0038251285566308375, + -0.01723884983037716, + 0.02881971942372201, + ], + "nn_indices": [0, 1], + "nn_scores": [ + 0.9164746999740601, + 0.782106876373291, + ], + }, + { + "persona": "A music teacher or instructor focused on theoretical and practical piano lessons.", + "embedding": [ + -0.0023464179614082125, + -0.07325472251663565, + -0.06058678419516501, + -0.02100326928586996, + -0.013462744792362657, + 0.027368447064244242, + -0.003916070100455717, + 0.01243614518480423, + ], + "nn_indices": [0, 2], + "nn_scores": [ + 0.7552462220191956, + 0.7261884808540344, + ], + }, + { + "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.", + "embedding": [ + -0.01630817942328242, + -0.023760151552345232, + -0.014249650090627883, + -0.005713686451446624, + -0.016033059279131567, + 0.0071440908501058786, + -0.05691099643425161, + 0.01597412704817784, + ], + "nn_indices": [1, 2], + "nn_scores": [ + 0.8107735514640808, + 0.7172299027442932, + ], + }, + { + "persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.", + "embedding": [ + -0.01630817942328242, + -0.023760151552345232, + -0.014249650090627883, + -0.005713686451446624, + -0.016033059279131567, + 0.0071440908501058786, + -0.05691099643425161, + 0.01597412704817784, + ], + "nn_indices": [], + "nn_scores": [], + }, +] + + +class TestEmbeddingDedup: + @pytest.mark.parametrize( + "threshold, keep_row_after_embedding_filtering", + [(0.1, 1), (0.9, 3), (0.99999, 4)], + ) + def test_process( + self, threshold: float, keep_row_after_embedding_filtering: int + ) -> None: + step = EmbeddingDedup(threshold=threshold) + step.load() + result = next(step.process(SAMPLE_DATA)) + duplicated = [r["keep_row_after_embedding_filtering"] for r in result] + assert sum(duplicated) == keep_row_after_embedding_filtering