diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 375df5b9..58442c7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,5 +40,3 @@ jobs: # TODO: Remove env variable when gpu dependencies are optional run: | RAPIDS_NO_INITIALIZE=1 python -m pytest -v --cpu - - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..6f550fc9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +default_language_version: + python: python3 + +ci: + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' + autoupdate_schedule: quarterly + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-case-conflict + - id: check-yaml + - id: detect-private-key + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: trailing-whitespace + + - repo: https://github.com/psf/black + rev: 24.3.0 + hooks: + - id: black + name: Format code + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + name: Format imports + exclude: docs/ diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 4861cafe..00000000 --- a/.style.yapf +++ /dev/null @@ -1,3 +0,0 @@ -[style] -based_on_style = google -indent_width = 2 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3bd14fc1..b8ba733f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -52,7 +52,7 @@ We use ``black`` as our style guide. To fix your format run `pip install pre-com 1. Minimize the use of ``**kwargs``. 1. ``RaiseError`` is preferred to ``assert``. Write: ```if X: raise Error``` instead of ```assert X```. 1. Classes are preferred to standalone methods. -1. Methods should be atomic. A method shouldn't be longer than 75 lines, e.g. can be fit into the computer screen without scrolling. +1. Methods should be atomic. A method shouldn't be longer than 88 lines, e.g. can be fit into the computer screen without scrolling. 1. If a method has arguments that don't fit into one line, each argument should be in its own line for readability. 1. Add ``__init__.py`` for every folder. 1. F-strings are prefered to formatted strings. diff --git a/README.md b/README.md index e1b9fe72..6eec3138 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ We currently support the following data-curation modules. 
For more details on ea - [Text reformatting and cleaning](docs/user-guide/LanguageIdentificationUnicodeFormatting.rst) - Fix unicode decoding errors via [ftfy](https://ftfy.readthedocs.io/en/latest/) - [Quality filtering](docs/user-guide/QualityFiltering.rst) - - Multilingual heuristic-based filtering + - Multilingual heuristic-based filtering - Classifier-based filtering via [fastText](https://fasttext.cc/) - [Document-level deduplication](docs/user-guide/GpuDeduplication.rst) - Both exact and fuzzy deduplication are accelerated using cuDF and Dask. @@ -79,7 +79,7 @@ Note: This is not the only way to run NeMo Curator on SLURM. There are example s ## Module Ablation and Compute Performance -The modules within NeMo Curator were in large part designed to curate high-quality documents from Common Crawl snapshots and to be able to do so +The modules within NeMo Curator were in large part designed to curate high-quality documents from Common Crawl snapshots and to be able to do so in a scalable manner. In order to assess the quality of the Common Crawl documents curated by the modules in NeMo Curator, we performed a series of ablation experiments in which we trained a 357M-parameter GPT-style model on the datasets resulting from the different stages of our data curation pipeline implemented in NeMo Curator. The figure below demonstrates that the different data curation modules implemented within NeMo Curator @@ -89,7 +89,7 @@ lead to improved model zero-shot downstream task performance. drawing
-In terms of scalability and compute performance, using the RAPIDS + Dask fuzzy deduplication, we are able to deduplicate the 1.1 Trillion token Red Pajama dataset in 1.8 hours using 64 A100s. +In terms of scalability and compute performance, using the RAPIDS + Dask fuzzy deduplication, we are able to deduplicate the 1.1 Trillion token Red Pajama dataset in 1.8 hours using 64 A100s. Additionally, using the CPU-based modules the table below shows the time required and resulting data size reduction of each step of processing the [Common Crawl snapshot from November/December of 2020](https://commoncrawl.org/2020/12/nov-dec-2020-crawl-archive-now-available/) using 30 CPU nodes (with hardware similar to the `c5.24xlarge` [Amazon AWS C5 instance](https://aws.amazon.com/ec2/instance-types/c5/)): @@ -128,4 +128,4 @@ Additionally, using the CPU-based modules the table below shows the time require As mentioned above, the modules within NeMo Curator enable users to scale data-mining and NLP processing tasks to many nodes within a compute cluster. The modules accomplish this using [Dask](https://www.dask.org/) with [cuDF](https://docs.rapids.ai/api/cudf/nightly/user_guide/10min/) (for the GPU-accelerated modules). -At the core of the NeMo Curator, `DocumentDataset` (the main dataset class) is just a simple wrapper around a Dask dataframe. Dask allows NeMo Curator to scale to arbitrary cluster sizes, and it supports a variety of distributed computing platforms. It supports reading and writing to different file formats, and it can balance these operations among nodes in the cluster. Importantly, Dask also supports the RAPIDS cuDF library for GPU-acclerated exact and fuzzy deduplication. \ No newline at end of file +At the core of the NeMo Curator, `DocumentDataset` (the main dataset class) is just a simple wrapper around a Dask dataframe. Dask allows NeMo Curator to scale to arbitrary cluster sizes, and it supports a variety of distributed computing platforms. It supports reading and writing to different file formats, and it can balance these operations among nodes in the cluster. Importantly, Dask also supports the RAPIDS cuDF library for GPU-acclerated exact and fuzzy deduplication. 
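As a rough illustration of the `DocumentDataset` wrapper the README describes, the sketch below strings together the read, filter, and write flow that appears in the `docs/user-guide/DocumentDataset.rst` and `docs/user-guide/QualityFiltering.rst` hunks of this diff. It is a minimal sketch, not part of the patch: the `books_dataset/` and `long_books/` paths are placeholders, and the `WordCountFilter` import location is assumed from the repo's `nemo_curator.filters` package.

```python
import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset
from nemo_curator.filters import WordCountFilter  # assumed import path
from nemo_curator.utils.file_utils import get_all_files_paths_under

# Read a directory of .jsonl shards into a Dask-backed DocumentDataset
files = get_all_files_paths_under("books_dataset/")  # placeholder path
books = DocumentDataset.read_json(files, add_filename=True)

# Score each document by word count and keep only sufficiently long ones
filter_step = nc.ScoreFilter(
    WordCountFilter(min_words=80),
    text_field="text",
    score_field="word_count",
)
long_books = filter_step(books)

# Write the kept documents back out, preserving the original shard filenames
long_books.to_json("long_books/", write_to_filename=True)
```

Because the dataset wraps a Dask dataframe, the filter is only evaluated when the final write triggers computation.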
diff --git a/SECURITY.md b/SECURITY.md index 2be787ab..34137c32 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -21,4 +21,4 @@ While NVIDIA currently does not have a bug bounty program, we do offer acknowled ## NVIDIA Product Security -For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security \ No newline at end of file +For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security diff --git a/config/arxiv_builder.yaml b/config/arxiv_builder.yaml index 007566d6..7bc5cb66 100644 --- a/config/arxiv_builder.yaml +++ b/config/arxiv_builder.yaml @@ -1,11 +1,11 @@ download_module: nemo_curator.download.arxiv.ArxivDownloader download_params: {} iterator_module: nemo_curator.download.arxiv.ArxivIterator -iterator_params: +iterator_params: log_frequency: 1000 extract_module: nemo_curator.download.arxiv.ArxivExtractor extract_params: {} format: text: str id: str - source_id: str \ No newline at end of file + source_id: str diff --git a/config/cc_warc_builder.yaml b/config/cc_warc_builder.yaml index 3e6a8ed9..7975ab80 100644 --- a/config/cc_warc_builder.yaml +++ b/config/cc_warc_builder.yaml @@ -9,4 +9,4 @@ format: language: str url: str warc_id: str - source_id: str \ No newline at end of file + source_id: str diff --git a/config/heuristic_filter_code.yaml b/config/heuristic_filter_code.yaml index 7d6b36e4..897ce4d0 100644 --- a/config/heuristic_filter_code.yaml +++ b/config/heuristic_filter_code.yaml @@ -1,7 +1,7 @@ input_field: text filters: # The filters below define a chain of heuristic filters to be applied to each document in a corpus. - # This particular cascade of filters is intended to filter Python code data. + # This particular cascade of filters is intended to filter Python code data. # The filter listed at the top will be applied first, and the following filters will be applied in # the order they appear in this file. Each filter can be removed and re-ordered as desired. # Change this based on the language of the data diff --git a/config/heuristic_filter_en.yaml b/config/heuristic_filter_en.yaml index 4e3bbb79..d4c05f97 100644 --- a/config/heuristic_filter_en.yaml +++ b/config/heuristic_filter_en.yaml @@ -1,7 +1,7 @@ input_field: text filters: # The filters below define a chain of heuristic filters to be applied to each document in a corpus. - # This particular cascade of filters is intended to filter English language data. + # This particular cascade of filters is intended to filter English language data. # The filter listed at the top will be applied first, and the following filters will be applied in # the order they appear in this file. Each filter can be removed and re-ordered as desired. 
- name: nemo_curator.filters.heuristic_filter.NonAlphaNumericFilter @@ -14,16 +14,16 @@ filters: params: max_number_to_text_ratio: 0.15 - name: nemo_curator.filters.heuristic_filter.UrlsFilter - params: + params: max_url_to_text_ratio: 0.2 - name: nemo_curator.filters.heuristic_filter.WhiteSpaceFilter - params: + params: max_white_space_ratio: 0.25 - name: nemo_curator.filters.heuristic_filter.ParenthesesFilter - params: + params: max_parentheses_ratio: 0.1 - name: nemo_curator.filters.heuristic_filter.BoilerPlateStringFilter - params: + params: remove_if_at_top_or_bottom: True max_boilerplate_string_ratio: 0.4 - name: nemo_curator.filters.heuristic_filter.RepeatedLinesFilter @@ -46,7 +46,7 @@ filters: params: max_num_sentences_without_endmark_ratio: 0.85 - name: nemo_curator.filters.heuristic_filter.WordsWithoutAlphabetsFilter - params: + params: min_words_with_alphabets: 0.8 - name: nemo_curator.filters.heuristic_filter.CommonEnglishWordsFilter params: @@ -54,10 +54,10 @@ filters: stop_at_false: True - name: nemo_curator.filters.heuristic_filter.MeanWordLengthFilter params: - max_mean_word_length: 10 + max_mean_word_length: 10 min_mean_word_length: 3 - name: nemo_curator.filters.heuristic_filter.LongWordFilter - params: + params: max_word_length: 1000 - name: nemo_curator.filters.heuristic_filter.EllipsisFilter params: @@ -102,4 +102,4 @@ filters: max_repeating_duplicate_ngram_ratio: 0.10 - name: nemo_curator.filters.heuristic_filter.BulletsFilter params: - max_bullet_lines_ratio: 0.9 \ No newline at end of file + max_bullet_lines_ratio: 0.9 diff --git a/config/heuristic_filter_non-en.yaml b/config/heuristic_filter_non-en.yaml index 783d0e54..7c456fb2 100644 --- a/config/heuristic_filter_non-en.yaml +++ b/config/heuristic_filter_non-en.yaml @@ -1,7 +1,7 @@ input_field: text filters: # The filters below define a chain of heuristic filters to be applied to each document in a corpus. - # This particular cascade of filters is intended to filter generic non-English data that use spaces for separating words. + # This particular cascade of filters is intended to filter generic non-English data that use spaces for separating words. # The filter listed at the top will be applied first, and the following filters will be applied in # the order they appear in this file. Each filter can be removed and re-ordered as desired. 
- name: nemo_curator.filters.heuristic_filter.SymbolsToWordsFilter @@ -11,16 +11,16 @@ filters: params: max_number_to_text_ratio: 0.15 - name: nemo_curator.filters.heuristic_filter.UrlsFilter - params: + params: max_url_to_text_ratio: 0.2 - name: nemo_curator.filters.heuristic_filter.WhiteSpaceFilter - params: + params: max_white_space_ratio: 0.25 - name: nemo_curator.filters.heuristic_filter.ParenthesesFilter - params: + params: max_parentheses_ratio: 0.1 - name: nemo_curator.filters.heuristic_filter.BoilerPlateStringFilter - params: + params: remove_if_at_top_or_bottom: True max_boilerplate_string_ratio: 0.4 - name: nemo_curator.filters.heuristic_filter.RepeatedLinesFilter @@ -39,17 +39,17 @@ filters: params: min_words: 50 max_words: 100000 - # NOTE: This filter tends to remove many documents and will need to + # NOTE: This filter tends to remove many documents and will need to # be tuned per language - name: nemo_curator.filters.heuristic_filter.PunctuationFilter params: max_num_sentences_without_endmark_ratio: 0.85 - name: nemo_curator.filters.heuristic_filter.MeanWordLengthFilter params: - max_mean_word_length: 10 + max_mean_word_length: 10 min_mean_word_length: 3 - name: nemo_curator.filters.heuristic_filter.LongWordFilter - params: + params: max_word_length: 1000 - name: nemo_curator.filters.heuristic_filter.EllipsisFilter params: @@ -94,4 +94,4 @@ filters: max_repeating_duplicate_ngram_ratio: 0.10 - name: nemo_curator.filters.heuristic_filter.BulletsFilter params: - max_bullet_lines_ratio: 0.9 \ No newline at end of file + max_bullet_lines_ratio: 0.9 diff --git a/config/lm_tasks.yaml b/config/lm_tasks.yaml index 3d38ec6f..d108ee2d 100644 --- a/config/lm_tasks.yaml +++ b/config/lm_tasks.yaml @@ -1,6 +1,6 @@ tasks: # The Python modules below define language model downstream evaluation - # task data. If one of the below tasks is specified, N-grams will + # task data. If one of the below tasks is specified, N-grams will # be constructed from the documents that make up the task data # using the script prepare_task_data. # find_matching_ngrams will then search for these N-grams diff --git a/config/pii_config.yaml b/config/pii_config.yaml index 725fde30..a693fa78 100644 --- a/config/pii_config.yaml +++ b/config/pii_config.yaml @@ -13,4 +13,4 @@ pii_config: #type: 'hash' #hash_type: 'sha256' - #type: 'redact' \ No newline at end of file + #type: 'redact' diff --git a/config/wikipedia_builder.yaml b/config/wikipedia_builder.yaml index 47831537..87215501 100644 --- a/config/wikipedia_builder.yaml +++ b/config/wikipedia_builder.yaml @@ -12,4 +12,4 @@ format: id: str url: str language: str - source_id: str \ No newline at end of file + source_id: str diff --git a/docs/user-guide/CPUvsGPU.rst b/docs/user-guide/CPUvsGPU.rst index c3159b21..5fd901d1 100644 --- a/docs/user-guide/CPUvsGPU.rst +++ b/docs/user-guide/CPUvsGPU.rst @@ -95,4 +95,4 @@ Every SLURM cluster is different, so make sure you understand how your SLURM clu ``start-slurm.sh`` calls ``containter-entrypoint.sh`` which sets up a Dask scheduler and workers across the cluster. Our Python examples are designed to work such that they can be run locally on their own, or easily substituted into the ``start-slurm.sh`` to run on multiple nodes. -You can adapt your scripts easily too by simply following the pattern of adding ``get_client`` with ``add_distributed_args``. \ No newline at end of file +You can adapt your scripts easily too by simply following the pattern of adding ``get_client`` with ``add_distributed_args``. 
diff --git a/docs/user-guide/DistributedDataClassification.rst b/docs/user-guide/DistributedDataClassification.rst index b7a99a20..f2bf098d 100644 --- a/docs/user-guide/DistributedDataClassification.rst +++ b/docs/user-guide/DistributedDataClassification.rst @@ -8,7 +8,7 @@ Background When preparing text data to be used in training a large language model (LLM), it is useful to classify text documents in various ways, to enhance the LLM's performance by making it able to produce more -contextually appropriate and accurate language across various subjects. NeMo Curator provides this module to +contextually appropriate and accurate language across various subjects. NeMo Curator provides this module to help a user run inference with pre-trained models on large amounts of text documents. We achieve this by chunking the datasets across multiple computing nodes, each equipped with multiple GPUs, to accelerate the classification task in a distributed way. In other words, because the classification of @@ -68,4 +68,4 @@ The key differences is that it operates on the GPU instead of the CPU. Therefore, the Dask cluster must be started as a GPU one. And, ``DomainClassifier`` requires ``DocumentDataset`` to be on the GPU (i.e., have ``backend=cudf``). It is easy to extend ``DistributedDataClassifier`` to your own model. -Check out ``nemo_curator.modules.distributed_data_classifier.py`` for reference. \ No newline at end of file +Check out ``nemo_curator.modules.distributed_data_classifier.py`` for reference. diff --git a/docs/user-guide/DocumentDataset.rst b/docs/user-guide/DocumentDataset.rst index 8711227a..351e41a9 100644 --- a/docs/user-guide/DocumentDataset.rst +++ b/docs/user-guide/DocumentDataset.rst @@ -48,7 +48,7 @@ You could read, filter the dataset, and write it using the following methods text_field="text", score_field="word_count", ) - + long_books = filter_step(books) long_books.to_json("long_books/", write_to_filename=True) @@ -106,7 +106,7 @@ Consider a modified version of the code above: text_field="text", score_field="word_count", ) - + long_books = filter_step(books) long_books.to_json("long_books/", write_to_filename=True) @@ -130,10 +130,10 @@ In these cases, we recommend processing the input dataset in batches using a sim text_field="text", score_field="word_count", ) - + long_books = filter_step(books) long_books.to_json("long_books/", write_to_filename=True) This will read in 64 shards at a time, process them, and write them back to disk. -Like ``get_remaining_files``, it only includes files that are in the input directory and not in the output directory. \ No newline at end of file +Like ``get_remaining_files``, it only includes files that are in the input directory and not in the output directory. diff --git a/docs/user-guide/Download.rst b/docs/user-guide/Download.rst index 66a34463..e2142de7 100644 --- a/docs/user-guide/Download.rst +++ b/docs/user-guide/Download.rst @@ -91,7 +91,7 @@ datasets. In general, it can be called as follows in order to download and extra --builder-config-file= \ --output-json-dir= -This utility takes as input a list of URLs that point to files that contain prepared, unextracted data (e.g., pre-crawled web pages from Common Crawl), a config file that describes how to download and extract the data, and the output directory to where the extracted text will be written in jsonl format (one json written to each document per line). For each URL provided in the list of URLs, a corresponding jsonl file will be written to the output directory. 
+This utility takes as input a list of URLs that point to files that contain prepared, unextracted data (e.g., pre-crawled web pages from Common Crawl), a config file that describes how to download and extract the data, and the output directory to where the extracted text will be written in jsonl format (one json written to each document per line). For each URL provided in the list of URLs, a corresponding jsonl file will be written to the output directory. The config file that must be provided at runtime, should take the following form @@ -133,7 +133,7 @@ If you would prefer to use this over `wget ` Downloading and Extracting Common Crawl ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -As described in the first section of this document, the first step towards using the :code:`download_and_extract` for Common Crawl will be to create a list of URLs that point to the location of the WARC files hosted by Common Crawl. +As described in the first section of this document, the first step towards using the :code:`download_and_extract` for Common Crawl will be to create a list of URLs that point to the location of the WARC files hosted by Common Crawl. Within NeMo Curator, we provide the utility :code:`get_common_crawl_urls` to obtain these urls. This utility can be run as follows .. code-block:: bash @@ -144,9 +144,9 @@ Within NeMo Curator, we provide the utility :code:`get_common_crawl_urls` to obt --ending-snapshot="2020-50" \ --output-warc-url-file=./url_data/warc_urls_cc_2020_50.txt -This script pulls the Common Crawl index from `https://index.commoncrawl.org` and stores the index to the file -specified by the argument :code:`--cc-snapshot-index-file`. It then retrieves all WARC urls between the -dates specified by the arguments :code:`--starting-snapshot` and :code:`--ending-snapshot`. +This script pulls the Common Crawl index from `https://index.commoncrawl.org` and stores the index to the file +specified by the argument :code:`--cc-snapshot-index-file`. It then retrieves all WARC urls between the +dates specified by the arguments :code:`--starting-snapshot` and :code:`--ending-snapshot`. Finally, it writes all WARC urls to the text file :code:`--output-warc-urls`. This file is a simple text file with the following format:: @@ -175,16 +175,15 @@ example of a single line of an output `.jsonl` file extracted from a WARC record .. code-block:: json - {"text": "커뮤니티\n\n어린이 요리 교실은 평소 조리와 제과 제빵에 관심이 있는 초등학생을 대상으로 나이프스킬, 한식, 중식, 양식, 제과, 제빵, 디저트, - 생활요리 등 요리 기초부터 시작해 다양한 요리에 대해 배우고, 경험할 수 있도록 구성되었다.\n\n요즘 부모들의 자녀 요리 교육에 대한 관심이 높아지고 - 있는데, 어린이 요리교실은 자녀들이 어디서 어떻게 요리를 처음 시작할지 막막하고 어려워 고민하는 이들을 위해 만들어졌다.\n\n그 뿐만 아니라 학생들이 - 식재료를 다루는 과정에서 손으로 만지고 느끼는 것이 감각을 자극하여 두뇌발달에 도움을 주며, 조리를 통해 자신의 감정을 자연스럽게 표현할 수 - 있고 이를 통해 정서적 안정을 얻을 수 있다. 또한, 다양한 사물을 만져 보면서 차이점을 구별하고 사물의 특징에 대해 인지할 수 있으므로 인지 능력 향상에 - 도움이 되며, 만지고 느끼고 비교하는 과정에서 감각 기능을 향상시킬 수 있다.\n\n방과 후 시간이 되지 않는 초등학생들을 위해 평일반 뿐만 아니라 주말반도 - 운영하고 있으며 두 분의 선생님들의 안전적인 지도하에 수업이 진행된다. 한국조리예술학원은 젊은 감각과 학생들과의 소통을 통해 자발적인 교육을 가르친다. 
- 자세한 학원 문의는 한국조리예술학원 홈페이지나 대표 전화, 카카오톡 플러스친구를 통해 가능하다.", "id": "a515a7b6-b6ec-4bed-998b-8be2f86f8eac", - "source_id": "https://data.commoncrawl.org/crawl-data/CC-MAIN-2020-50/segments/1606141163411.0/warc/CC-MAIN-20201123153826-20201123183826-00000.warc.gz", + {"text": "커뮤니티\n\n어린이 요리 교실은 평소 조리와 제과 제빵에 관심이 있는 초등학생을 대상으로 나이프스킬, 한식, 중식, 양식, 제과, 제빵, 디저트, + 생활요리 등 요리 기초부터 시작해 다양한 요리에 대해 배우고, 경험할 수 있도록 구성되었다.\n\n요즘 부모들의 자녀 요리 교육에 대한 관심이 높아지고 + 있는데, 어린이 요리교실은 자녀들이 어디서 어떻게 요리를 처음 시작할지 막막하고 어려워 고민하는 이들을 위해 만들어졌다.\n\n그 뿐만 아니라 학생들이 + 식재료를 다루는 과정에서 손으로 만지고 느끼는 것이 감각을 자극하여 두뇌발달에 도움을 주며, 조리를 통해 자신의 감정을 자연스럽게 표현할 수 + 있고 이를 통해 정서적 안정을 얻을 수 있다. 또한, 다양한 사물을 만져 보면서 차이점을 구별하고 사물의 특징에 대해 인지할 수 있으므로 인지 능력 향상에 + 도움이 되며, 만지고 느끼고 비교하는 과정에서 감각 기능을 향상시킬 수 있다.\n\n방과 후 시간이 되지 않는 초등학생들을 위해 평일반 뿐만 아니라 주말반도 + 운영하고 있으며 두 분의 선생님들의 안전적인 지도하에 수업이 진행된다. 한국조리예술학원은 젊은 감각과 학생들과의 소통을 통해 자발적인 교육을 가르친다. + 자세한 학원 문의는 한국조리예술학원 홈페이지나 대표 전화, 카카오톡 플러스친구를 통해 가능하다.", "id": "a515a7b6-b6ec-4bed-998b-8be2f86f8eac", + "source_id": "https://data.commoncrawl.org/crawl-data/CC-MAIN-2020-50/segments/1606141163411.0/warc/CC-MAIN-20201123153826-20201123183826-00000.warc.gz", "url": "http://hanjowon.co.kr/web/home.php?mid=70&go=pds.list&pds_type=1&start=20&num=67&s_key1=&s_que=", "language": "KOREAN"} Once all records have been processed within a WARC file, it is by default deleted from disk. - diff --git a/docs/user-guide/GpuDeduplication.rst b/docs/user-guide/GpuDeduplication.rst index d23e8ee7..d8b54811 100644 --- a/docs/user-guide/GpuDeduplication.rst +++ b/docs/user-guide/GpuDeduplication.rst @@ -10,7 +10,7 @@ Background ----------------------------------------- Training on randomly selected documents for many epochs can be sub-optimal to downstream performance for language models. -For more information on when this is harmful, please see `Muennighoff et al., 2023 `_ and `Tirumala et al., 2023 `_. +For more information on when this is harmful, please see `Muennighoff et al., 2023 `_ and `Tirumala et al., 2023 `_. The exact and fuzzy document-level deduplication module in the NeMo Curator aims at reducing the occurence of duplicate and near-duplicate documents in the dataset. Exact deduplication refers to removing identical (i.e., document strings are equal) documents from the dataset, while fuzzy deduplication refers to removing near-identical (e.g., an excerpt of a document is used in another document) @@ -27,7 +27,7 @@ As exact deduplication is a much less involved procedure and requires significan we typically will first run exact deduplication before fuzzy deduplication. Also, from our experience in deduplicating Common Crawl snapshots, a significant portion of the duplicates are in fact exact duplicates. -When removing near-duplicates within the corpus we perform fuzzy deduplication at the document level in order to remove documents that +When removing near-duplicates within the corpus we perform fuzzy deduplication at the document level in order to remove documents that have high Jaccard similarity. Our approach closely resembles the approach described in `Smith et al., 2020 `_. This approach can essentially be split into two conceptual changes. The first stage involves computing MinHashes Signatures on documents and then performing Locality Sensitive Hashing (LSH) to find candidate duplucates. 
Due to the approximate nature of the bucketing via MinHash + LSH @@ -35,11 +35,11 @@ documents and then performing Locality Sensitive Hashing (LSH) to find candidate -Before running either of these modules, users should assign a unique document ID to each document in the corpus. +Before running either of these modules, users should assign a unique document ID to each document in the corpus. This can be accomplished using the :code:`add_id` module within the NeMo Curator: .. code-block:: bash - + add_id \ --input-data-dir= \ --log-dir=./log/add_id @@ -47,7 +47,7 @@ This can be accomplished using the :code:`add_id` module within the NeMo Curator By default, this will create a new field named :code:`adlr_id` within each json document which will have the form "doc_prefix-000001". If the dataset already has a unique ID this step can be skipped. -**Note**: Fuzzy deduplication only works with numeric ID's or the specific ID format generated by the :code:`add_id` script. If the +**Note**: Fuzzy deduplication only works with numeric ID's or the specific ID format generated by the :code:`add_id` script. If the dataset does not contain ID's in this format it's recommended to convert to an integer based ID or ID created by the :code:`add_id` script. Once a unique ID has been added to each document, users can proceed with exact and fuzzy deduplication which roughly require the following @@ -80,4 +80,3 @@ steps (all scripts are included in the :code:`nemo_curator/scripts/` subdirector In addition to the scripts, there are examples in the `examples` directory that showcase using the python module directly in your own code. It also has examples on how to remove documents from the corpus using the list of duplicate IDs generated from exact or fuzzy deduplication. - diff --git a/docs/user-guide/LanguageIdentificationUnicodeFormatting.rst b/docs/user-guide/LanguageIdentificationUnicodeFormatting.rst index ddd107bf..3e61f8f7 100644 --- a/docs/user-guide/LanguageIdentificationUnicodeFormatting.rst +++ b/docs/user-guide/LanguageIdentificationUnicodeFormatting.rst @@ -40,7 +40,7 @@ Here is the implementation of the ``UnicodeReformatter`` modifier: class UnicodeReformatter(DocumentModifier): def __init__(self): super().__init__() - + def modify_document(self, text: str) -> str: return ftfy.fix_text(text) @@ -51,7 +51,7 @@ Related Scripts ----------------------------------------- To perform the language identification, we can use the config file provided in the `config` directory -and provide the path to a local copy of the `lid.176.bin` language identification fastText model. Then, with the general purpose +and provide the path to a local copy of the `lid.176.bin` language identification fastText model. Then, with the general purpose :code:`filter_documents` tool, we can compute language scores and codes for each document in the corpus as follows .. code-block:: bash @@ -77,7 +77,7 @@ within that file. Below is an example run command for :code:`separate_by_metadat --input-metadata-field=language \ --output-data-dir= \ --output-metadata-distribution=./data/lang_distro.json - + After running this module, the output directory will consist of one directory per language present within the corpus and all documents within those directories will contain text that originates from the same language. 
Finally, the text within a specific language can have its unicode fixed using the :code:`text_cleaning` module diff --git a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst index 9e05b688..3f2ebcc6 100644 --- a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst +++ b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst @@ -101,4 +101,3 @@ Resuming from Interruptions It can be helpful to track which documents in a dataset have already been processed so that long curation jobs can be resumed if they are interrupted. NeMo Curator provides a utility for easily tracking which dataset shards have already been processed. A call to ``get_batched_files`` will return an iterator over the files that have yet to be processed by a modifier such as ``PiiModifierBatched`` When you re-run the code example provided above, NeMo Curator ensures that only unprocessed files are processed by the PII module. - diff --git a/docs/user-guide/QualityFiltering.rst b/docs/user-guide/QualityFiltering.rst index 6d044e22..46a8c9d8 100644 --- a/docs/user-guide/QualityFiltering.rst +++ b/docs/user-guide/QualityFiltering.rst @@ -43,7 +43,7 @@ Let's examine this small example: text_field="text", score_field="word_count", ) - + long_books = filter_step(books) long_books.to_json("long_books/", write_to_filename=True) @@ -101,7 +101,7 @@ For example, if the dataset in the above example came pre-populated with the ``w WordCountFilter(min_words=80).keep_document, filter_field="word_count", ) - + long_books = filter_step(books) long_books.to_json("long_books/", write_to_filename=True) @@ -117,7 +117,7 @@ Alternatively, if you simply want to track the length of the words in the docume text_field="text", score_field="word_count", ) - + annotated_books = filter_step(books) annotated_books.to_json("annotated_books/", write_to_filename=True) @@ -161,15 +161,15 @@ Classifier Filtering The classifier-based filtering approach we have implemented follows closely to that used in `Brown et al., 2020 `_, and trains a binary skip-gram classifier that can be used to distinguish between low and high quality documents. To implement this, we use the -functions provided by fastText. Following the examples provided in the fastText documentation, we first create a file consisting of +functions provided by fastText. Following the examples provided in the fastText documentation, we first create a file consisting of high and low-quality training documents. We provide an example of how to train and use a model in ``examples/classifier_filtering.py``. We also provide CLI scripts for the same functionality. The :code:`prepare_fasttext_training_data` script will randomly sample documents -from an input dataset and will prepare them to be used to train a fasText skip-gram classifier. For a high-quality dataset we recommend sampling from +from an input dataset and will prepare them to be used to train a fasText skip-gram classifier. For a high-quality dataset we recommend sampling from either OpenWebText2 or Wikipedia and an unfiltered version of Common Crawl can be used for a low-quality dataset. .. 
code-block:: bash - + prepare_fasttext_training_data \ --input-data-dir= \ --output-num-samples= \ @@ -183,12 +183,12 @@ either OpenWebText2 or Wikipedia and an unfiltered version of Common Crawl can b --output-train-file=${res_dir}/hq_samples.txt \ Once the samples have been prepared and written to :code:`.txt` files, users can use the :code:`train_fasttext` script that reads in the samples within the :code:`.txt` files -in order to train a quality classifier. :code:`train_fasttext` will read in all of the samples within the :code:`.txt` files, split the data into training and +in order to train a quality classifier. :code:`train_fasttext` will read in all of the samples within the :code:`.txt` files, split the data into training and validation sets and train the binary skip-gram classifier. After training, it evaluates the model on the validation samples and writes the predictions to a jsonl file prints the confusion matrix to stdout. .. code-block:: bash - + train_fasttext \ --fasttext-files-dir=${res_dir} \ --output-train-file=${res_dir}/fasttext_samples.train \ @@ -202,7 +202,7 @@ be used for classifier-based quality filtering with a fastText model. Additional as is described in `Brown et al., 2020 `_. .. code-block:: bash - + filter_documents \ --input-data-dir= \ --filter-config-file=./config/fasttext_quality_filter.yaml \ @@ -236,7 +236,7 @@ The filters are general enough that users should feel free to remove certain fil with the results of different filter configurations/parameters. Additionally, these filters have been used for curating high-quality non-English documents. However, it is advised that when applying -to non-English data that users write out the document scores by specifying the :code:`--document-score-dir` argument. This will allow users to +to non-English data that users write out the document scores by specifying the :code:`--document-score-dir` argument. This will allow users to examine if a particular filter is responsible for undesirably removing many documents from a corpus. .. code-block:: bash diff --git a/docs/user-guide/TaskDecontamination.rst b/docs/user-guide/TaskDecontamination.rst index 25bd6547..46a0d980 100644 --- a/docs/user-guide/TaskDecontamination.rst +++ b/docs/user-guide/TaskDecontamination.rst @@ -38,7 +38,7 @@ Let's examine this small example: ] task_decontaminate = nc.TaskDecontamination(downstream_tasks) - + decontaminated_books = task_decontaminate(books) decontaminated_books.to_json("decontaminated_books/", write_to_filename=True) diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index e4cbd75e..278e47ab 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -1,7 +1,7 @@ .. include:: DataCuration.rsts :ref:`Downloading and Extracting Text ` - Downloading a massive public dataset is usually the first step in data curation, and it can be cumbersome due to the dataset’s massive size and hosting method. This section describes how to download and extract large corpora efficiently. + Downloading a massive public dataset is usually the first step in data curation, and it can be cumbersome due to the dataset’s massive size and hosting method. This section describes how to download and extract large corpora efficiently. :ref:`Working with DocumentDataset ` DocumentDataset is the standard format for datasets in NeMo Curator. This section describes how to get datasets in and out of this format, as well as how DocumentDataset interacts with the modules. 
@@ -16,7 +16,7 @@ Large, unlabeled text corpora often contain a variety of languages. The NeMo Curator provides utilities to identify languages and fix improperly decoded Unicode characters. :ref:`GPU Accelerated Exact and Fuzzy Deduplication ` - Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. + Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. :ref:`Classifier and Heuristic Quality Filtering ` Classifier-based filtering involves training a small text classifer to label a document as either high quality or low quality. Heuristic-based filtering uses simple rules (e.g. Are there too many punctuation marks?) to score a document. NeMo Curator offers both classifier and heuristic-based quality filtering of documents. @@ -40,4 +40,4 @@ GpuDeduplication.rst TaskDecontamination.rst PersonalIdentifiableInformationIdentificationAndRemoval.rst - DistributedDataClassification.rst \ No newline at end of file + DistributedDataClassification.rst diff --git a/examples/classifier_filtering.py b/examples/classifier_filtering.py index 6c3588c7..dbe24c45 100644 --- a/examples/classifier_filtering.py +++ b/examples/classifier_filtering.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import fasttext -import random import argparse +import random + +import fasttext import nemo_curator as nc from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers import FastTextLabelModifier from nemo_curator.filters import BatchedFastTextQualityFilter +from nemo_curator.modifiers import FastTextLabelModifier +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under -from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client from nemo_curator.utils.script_utils import add_distributed_args @@ -32,15 +33,19 @@ def load_dataset(input_data_dir): return dataset + def create_samples(data_path, label, num_samples): raw_dataset = load_dataset(data_path) label_quality = nc.Modify(FastTextLabelModifier(label)) labeled_dataset = label_quality(raw_dataset) - labeled_samples = labeled_dataset.df.sample(frac=num_samples / len(labeled_dataset.df)) + labeled_samples = labeled_dataset.df.sample( + frac=num_samples / len(labeled_dataset.df) + ) return labeled_samples["text"].compute().values.tolist() + def main(args): # Params low_quality_data_path = "/path/to/low_quality" @@ -51,8 +56,12 @@ def main(args): # Prepare samples for the classifier client = get_client(args, args.device) - low_quality_samples = create_samples(low_quality_data_path, "__label__lq", num_low_quality_samples) - high_quality_samples = create_samples(high_quality_data_path, "__label__hq", num_high_quality_samples) + low_quality_samples = create_samples( + low_quality_data_path, "__label__lq", num_low_quality_samples + ) + high_quality_samples = create_samples( + high_quality_data_path, "__label__hq", num_high_quality_samples + ) train_samples = low_quality_samples + high_quality_samples random.shuffle(train_samples) @@ -61,7 +70,7 @@ def main(args): with open(train_file, "w") as f: for sample in train_samples: f.write(sample) - f.write('\n') + f.write("\n") # Train fastText classifier model = fasttext.train_supervised( @@ -75,14 +84,25 @@ def main(args): # Filter data target_dataset = load_dataset(low_quality_data_path) - 
filter_pipeline = nc.ScoreFilter(BatchedFastTextQualityFilter(model_path), score_field="quality_score", batched=True, score_type=float) + filter_pipeline = nc.ScoreFilter( + BatchedFastTextQualityFilter(model_path), + score_field="quality_score", + batched=True, + score_type=float, + ) filtered_dataset = filter_pipeline(target_dataset) # Write filtered dataset write_to_disk(filtered_dataset.df, filtered_output, write_to_filename=True) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/distributed_data_classification_examples/domain_api_example.py b/examples/distributed_data_classification_examples/domain_api_example.py index 55d44e3a..ad2fa1c8 100644 --- a/examples/distributed_data_classification_examples/domain_api_example.py +++ b/examples/distributed_data_classification_examples/domain_api_example.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os import time -import argparse - from nemo_curator import DomainClassifier from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.distributed_utils import get_client @@ -73,9 +72,7 @@ def main(args): ) result_dataset = domain_classifier(dataset=input_dataset) - result_dataset.to_json( - output_file_dir=output_file_path, write_to_filename=True - ) + result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True) global_et = time.time() print( @@ -85,7 +82,12 @@ def main(args): client.close() -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): parser.add_argument( "--scheduler-address", type=str, @@ -132,7 +134,7 @@ def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.Argument default="gpu", help="Device to run the script on. Either 'cpu' or 'gpu'.", ) - + return parser diff --git a/examples/distributed_data_classification_examples/quality_api_example.py b/examples/distributed_data_classification_examples/quality_api_example.py index 65e28d18..53b9849c 100644 --- a/examples/distributed_data_classification_examples/quality_api_example.py +++ b/examples/distributed_data_classification_examples/quality_api_example.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os import time -import argparse - from nemo_curator import QualityClassifier from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.distributed_utils import get_client @@ -54,7 +53,12 @@ def main(args): client.close() -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): parser.add_argument( "--scheduler-address", type=str, @@ -101,7 +105,7 @@ def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.Argument default="gpu", help="Device to run the script on. 
Either 'cpu' or 'gpu'.", ) - + return parser diff --git a/examples/download_arxiv.py b/examples/download_arxiv.py index 945457aa..a20d36bf 100644 --- a/examples/download_arxiv.py +++ b/examples/download_arxiv.py @@ -36,8 +36,14 @@ def main(args): # Inspect the samples print(sample.compute()) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/download_common_crawl.py b/examples/download_common_crawl.py index 2ded1532..103e88c5 100644 --- a/examples/download_common_crawl.py +++ b/examples/download_common_crawl.py @@ -32,14 +32,22 @@ def main(args): client = get_client(args, args.device) # Download and sample data - common_crawl = download_common_crawl(output_directory, start_snapshot, end_snapshot, url_limit=url_limit) + common_crawl = download_common_crawl( + output_directory, start_snapshot, end_snapshot, url_limit=url_limit + ) sample = common_crawl.df.sample(frac=10 / len(common_crawl)) # Inspect the samples print(sample.compute()) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/download_wikipedia.py b/examples/download_wikipedia.py index 377e2ee5..f6e42852 100644 --- a/examples/download_wikipedia.py +++ b/examples/download_wikipedia.py @@ -31,14 +31,22 @@ def main(args): client = get_client(args, args.device) # Download and sample data - wikipedia = download_wikipedia(output_directory, dump_date=dump_date, url_limit=url_limit) + wikipedia = download_wikipedia( + output_directory, dump_date=dump_date, url_limit=url_limit + ) sample = wikipedia.df.sample(frac=10 / len(wikipedia)) # Inspect the samples print(sample.compute()) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/exact_deduplication.py b/examples/exact_deduplication.py index 2cd50bb2..b722bf58 100644 --- a/examples/exact_deduplication.py +++ b/examples/exact_deduplication.py @@ -13,13 +13,14 @@ # limitations under the License. import argparse +import time from nemo_curator.datasets import DocumentDataset from nemo_curator.modules import ExactDuplicates from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under from nemo_curator.utils.script_utils import add_distributed_args -import time + def pre_imports(): import cudf # noqa: F401 @@ -65,10 +66,12 @@ def main(args): # When there are few duplicates we can compute the results to a list and use `isin`. 
result = input_dataset.df[ - ~input_dataset.df[dataset_id_field].isin(docs_to_remove[dataset_id_field].compute()) + ~input_dataset.df[dataset_id_field].isin( + docs_to_remove[dataset_id_field].compute() + ) ] write_to_disk(result, output_dir, output_type="parquet") - print(time.time()-t0) + print(time.time() - t0) def attach_args( diff --git a/examples/find_pii_and_deidentify.py b/examples/find_pii_and_deidentify.py index 0ac1eb56..9cb6d9bd 100644 --- a/examples/find_pii_and_deidentify.py +++ b/examples/find_pii_and_deidentify.py @@ -29,21 +29,24 @@ def console_script(): arguments = add_distributed_args(parser).parse_args() _ = get_client(arguments, arguments.device) - dataframe = pd.DataFrame({'text': ['Sarah and Ryan went out to play', 'Jensen is the CEO of NVIDIA']}) + dataframe = pd.DataFrame( + {"text": ["Sarah and Ryan went out to play", "Jensen is the CEO of NVIDIA"]} + ) dd = dask.dataframe.from_pandas(dataframe, npartitions=1) dataset = DocumentDataset(dd) modifier = PiiModifierBatched( - log_dir='./logs', + log_dir="./logs", batch_size=2000, - language='en', - supported_entities=['PERSON', "EMAIL_ADDRESS"], - anonymize_action='replace') + language="en", + supported_entities=["PERSON", "EMAIL_ADDRESS"], + anonymize_action="replace", + ) modify = Modify(modifier, batched=True) modified_dataset = modify(dataset) - modified_dataset.df.to_json('output_files/*.jsonl', lines=True, orient='records') + modified_dataset.df.to_json("output_files/*.jsonl", lines=True, orient="records") if __name__ == "__main__": - console_script() \ No newline at end of file + console_script() diff --git a/examples/identify_languages_and_fix_unicode.py b/examples/identify_languages_and_fix_unicode.py index b9fbe353..933c6c23 100644 --- a/examples/identify_languages_and_fix_unicode.py +++ b/examples/identify_languages_and_fix_unicode.py @@ -19,8 +19,11 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.filters import FastTextLangId from nemo_curator.modifiers import UnicodeReformatter -from nemo_curator.utils.file_utils import get_all_files_paths_under, separate_by_metadata -from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import ( + get_all_files_paths_under, + separate_by_metadata, +) from nemo_curator.utils.script_utils import add_distributed_args @@ -31,6 +34,7 @@ def load_dataset(input_data_dir): return dataset + def main(args): # Params multilingual_data_path = "/path/to/multilingual" @@ -39,7 +43,7 @@ def main(args): # Download a fastText language identification model # and see a list of supported languages here: - # https://fasttext.cc/docs/en/language-identification.html + # https://fasttext.cc/docs/en/language-identification.html model_path = "/path/to/model.bin" target_language = "EN" language_field = "language" @@ -49,14 +53,22 @@ def main(args): # Filter data multilingual_dataset = load_dataset(multilingual_data_path) - language_id_pipeline = nc.ScoreFilter(FastTextLangId(model_path), score_field=language_field, score_type='object') + language_id_pipeline = nc.ScoreFilter( + FastTextLangId(model_path), score_field=language_field, score_type="object" + ) filtered_dataset = language_id_pipeline(multilingual_dataset) - + # Remove the language score - filtered_dataset.df[language_field] = filtered_dataset.df[language_field].apply(lambda score: score[1]) + filtered_dataset.df[language_field] = 
filtered_dataset.df[language_field].apply( + lambda score: score[1] + ) # Split the dataset by language - language_stats = separate_by_metadata(filtered_dataset.df, language_separated_output_path, metadata_field=language_field).compute() + language_stats = separate_by_metadata( + filtered_dataset.df, + language_separated_output_path, + metadata_field=language_field, + ).compute() # Read the language specific data and fix the unicode in it lang_data_path = os.path.join(language_separated_output_path, target_language) @@ -70,8 +82,14 @@ def main(args): # Write the cleaned_data write_to_disk(cleaned_data.df, cleaned_data_output_path, write_to_filename=True) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/raw_download_common_crawl.py b/examples/raw_download_common_crawl.py index ade3fd67..ab40eaf8 100644 --- a/examples/raw_download_common_crawl.py +++ b/examples/raw_download_common_crawl.py @@ -14,11 +14,11 @@ import argparse -from nemo_curator.download import batch_download, CommonCrawlWARCDownloader -from nemo_curator.utils.download_utils import get_common_crawl_urls +from nemo_curator.download import CommonCrawlWARCDownloader, batch_download from nemo_curator.utils.distributed_utils import get_client -from nemo_curator.utils.script_utils import add_distributed_args +from nemo_curator.utils.download_utils import get_common_crawl_urls from nemo_curator.utils.file_utils import expand_outdir_and_mkdir +from nemo_curator.utils.script_utils import add_distributed_args def main(args): @@ -45,8 +45,14 @@ def main(args): for file in output_files: print(file) -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/examples/slurm/container-entrypoint.sh b/examples/slurm/container-entrypoint.sh index d3d49b17..8bc6a9a3 100755 --- a/examples/slurm/container-entrypoint.sh +++ b/examples/slurm/container-entrypoint.sh @@ -67,4 +67,4 @@ fi while [ ! 
-f $DONE_MARKER ] do sleep 15 -done \ No newline at end of file +done diff --git a/examples/slurm/start-slurm.sh b/examples/slurm/start-slurm.sh index 4454655d..02c211f6 100644 --- a/examples/slurm/start-slurm.sh +++ b/examples/slurm/start-slurm.sh @@ -80,4 +80,4 @@ mkdir -p $PROFILESDIR srun \ --container-mounts=${MOUNTS} \ --container-image=${CONTAINER_IMAGE} \ - ${CONTAINER_ENTRYPOINT} \ No newline at end of file + ${CONTAINER_ENTRYPOINT} diff --git a/examples/task_decontamination.py b/examples/task_decontamination.py index 91730455..8e2ae4cb 100644 --- a/examples/task_decontamination.py +++ b/examples/task_decontamination.py @@ -17,29 +17,29 @@ import nemo_curator as nc from nemo_curator.datasets import DocumentDataset from nemo_curator.tasks import ( - Winogrande, - Squad, - TriviaQA, - Quac, - WebQA, - Race, - Drop, - WiC, + ANLI, + CB, PIQA, - ArcEasy, + RTE, + WSC, ArcChallenge, - OpenBookQA, + ArcEasy, BoolQ, Copa, - RTE, + Drop, MultiRC, - WSC, - CB, - ANLI, - Record + OpenBookQA, + Quac, + Race, + Record, + Squad, + TriviaQA, + WebQA, + WiC, + Winogrande, ) +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under -from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client from nemo_curator.utils.script_utils import add_distributed_args @@ -50,6 +50,7 @@ def load_dataset(input_data_dir): return dataset + def main(args): # Params contaminated_dataset_path = "/path/to/input" @@ -75,7 +76,7 @@ def main(args): WSC(), CB(), ANLI(), - Record() + Record(), ] # Prepare samples for the classifier @@ -87,10 +88,18 @@ def main(args): decontaminated_dataset = decontaminator(target_dataset) # Write filtered dataset - write_to_disk(decontaminated_dataset.df, decontaminated_output_path, write_to_filename=True) + write_to_disk( + decontaminated_dataset.df, decontaminated_output_path, write_to_filename=True + ) + -def attach_args(parser=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)): +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): return add_distributed_args(parser) + if __name__ == "__main__": - main(attach_args().parse_args()) \ No newline at end of file + main(attach_args().parse_args()) diff --git a/nemo_curator/__init__.py b/nemo_curator/__init__.py index 26205067..000e459a 100644 --- a/nemo_curator/__init__.py +++ b/nemo_curator/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .modules import * \ No newline at end of file +from .modules import * diff --git a/nemo_curator/datasets/__init__.py b/nemo_curator/datasets/__init__.py index 276cbc10..af9695b2 100644 --- a/nemo_curator/datasets/__init__.py +++ b/nemo_curator/datasets/__init__.py @@ -14,4 +14,4 @@ from .doc_dataset import DocumentDataset -__all__ = ["DocumentDataset"] \ No newline at end of file +__all__ = ["DocumentDataset"] diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py index 396aa9b6..af45f290 100644 --- a/nemo_curator/datasets/doc_dataset.py +++ b/nemo_curator/datasets/doc_dataset.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import dask_cudf import dask.dataframe as dd +import dask_cudf from nemo_curator.utils.distributed_utils import read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under @@ -24,6 +24,7 @@ class DocumentDataset: A collection of documents and document metadata. Internally it may be distributed across multiple nodes, and may be on GPUs. """ + def __init__(self, dataset_df): self.df = dataset_df @@ -41,13 +42,15 @@ def read_json( files_per_partition=1, add_filename=False, ): - return cls(_read_json_or_parquet( - input_files=input_files, - file_type="jsonl", - backend=backend, - files_per_partition=files_per_partition, - add_filename=add_filename, - )) + return cls( + _read_json_or_parquet( + input_files=input_files, + file_type="jsonl", + backend=backend, + files_per_partition=files_per_partition, + add_filename=add_filename, + ) + ) @classmethod def read_parquet( @@ -57,13 +60,15 @@ def read_parquet( files_per_partition=1, add_filename=False, ): - return cls(_read_json_or_parquet( - input_files=input_files, - file_type="parquet", - backend=backend, - files_per_partition=files_per_partition, - add_filename=add_filename, - )) + return cls( + _read_json_or_parquet( + input_files=input_files, + file_type="parquet", + backend=backend, + files_per_partition=files_per_partition, + add_filename=add_filename, + ) + ) @classmethod def read_pickle( diff --git a/nemo_curator/distributed_data_classification/__init__.py b/nemo_curator/distributed_data_classification/__init__.py index fe99e99a..d9155f92 100644 --- a/nemo_curator/distributed_data_classification/__init__.py +++ b/nemo_curator/distributed_data_classification/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/nemo_curator/distributed_data_classification/generate_statistics.py b/nemo_curator/distributed_data_classification/generate_statistics.py index 5dae3ef9..17e76c92 100644 --- a/nemo_curator/distributed_data_classification/generate_statistics.py +++ b/nemo_curator/distributed_data_classification/generate_statistics.py @@ -27,7 +27,7 @@ def value_counts(df, column_name): """ This function groups a DataFrame by the specified column and counts the occurrences of each group. It is essentially the same as pandas.Series.value_counts, except it returns a DataFrame. - + Args: df: A DataFrame. column_name: The column by which to group the DataFrame. 
@@ -58,12 +58,16 @@ def main(): print("Starting statistics workflow", flush=True) st = time.time() - df = read_data( - input_files=get_all_files_paths_under(args.input_file_path, recurse_subdirecties=False), + df = read_data( + input_files=get_all_files_paths_under( + args.input_file_path, recurse_subdirecties=False + ), file_type=args.input_file_type, add_filename=True, - ) - input_files = get_all_files_paths_under(args.input_file_path, recurse_subdirecties=False) + ) + input_files = get_all_files_paths_under( + args.input_file_path, recurse_subdirecties=False + ) result = value_counts(df, column_name=args.label) result = result.rename(columns={0: "count"}) diff --git a/nemo_curator/download/__init__.py b/nemo_curator/download/__init__.py index 7b5d1387..3328b1c1 100644 --- a/nemo_curator/download/__init__.py +++ b/nemo_curator/download/__init__.py @@ -12,9 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .doc_builder import DocumentDownloader, DocumentIterator, DocumentExtractor, download_and_extract, import_downloader, import_extractor, import_iterator, batch_download -from .commoncrawl import download_common_crawl, CommonCrawlWARCDownloader, CommonCrawlWARCExtractor, CommonCrawlWARCIterator, CommonCrawlWARCDownloaderExtractOnly -from .wikipedia import download_wikipedia, WikipediaDownloader, WikipediaIterator, WikipediaExtractor -from .arxiv import download_arxiv, ArxivDownloader, ArxivIterator, ArxivExtractor +from .arxiv import ArxivDownloader, ArxivExtractor, ArxivIterator, download_arxiv +from .commoncrawl import ( + CommonCrawlWARCDownloader, + CommonCrawlWARCDownloaderExtractOnly, + CommonCrawlWARCExtractor, + CommonCrawlWARCIterator, + download_common_crawl, +) +from .doc_builder import ( + DocumentDownloader, + DocumentExtractor, + DocumentIterator, + batch_download, + download_and_extract, + import_downloader, + import_extractor, + import_iterator, +) +from .wikipedia import ( + WikipediaDownloader, + WikipediaExtractor, + WikipediaIterator, + download_wikipedia, +) -__all__ = ["DocumentDownloader", "DocumentIterator", "DocumentExtractor", "download_and_extract", "import_downloader", "import_extractor", "import_iterator", "download_common_crawl", "CommonCrawlWARCDownloader", "CommonCrawlWARCExtractor", "CommonCrawlWARCIterator", "CommonCrawlWARCDownloaderExtractOnly", "download_wikipedia", "WikipediaDownloader", "WikipediaIterator", "WikipediaExtractor", "batch_download", "download_arxiv", "ArxivDownloader", "ArxivIterator", "ArxivExtractor"] \ No newline at end of file +__all__ = [ + "DocumentDownloader", + "DocumentIterator", + "DocumentExtractor", + "download_and_extract", + "import_downloader", + "import_extractor", + "import_iterator", + "download_common_crawl", + "CommonCrawlWARCDownloader", + "CommonCrawlWARCExtractor", + "CommonCrawlWARCIterator", + "CommonCrawlWARCDownloaderExtractOnly", + "download_wikipedia", + "WikipediaDownloader", + "WikipediaIterator", + "WikipediaExtractor", + "batch_download", + "download_arxiv", + "ArxivDownloader", + "ArxivIterator", + "ArxivExtractor", +] diff --git a/nemo_curator/download/arxiv.py b/nemo_curator/download/arxiv.py index d9556db8..68001696 100644 --- a/nemo_curator/download/arxiv.py +++ b/nemo_curator/download/arxiv.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gzip import os import re -import gzip -import tempfile -import tarfile import subprocess +import tarfile +import tempfile from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.file_utils import get_all_files_paths_under, expand_outdir_and_mkdir from nemo_curator.download.doc_builder import ( DocumentDownloader, - DocumentIterator, DocumentExtractor, - download_and_extract + DocumentIterator, + download_and_extract, ) from nemo_curator.utils.download_utils import get_arxiv_urls +from nemo_curator.utils.file_utils import ( + expand_outdir_and_mkdir, + get_all_files_paths_under, +) # The iterator and extractor code are in large part taken # from the Red-Pajama repo @@ -36,375 +39,398 @@ class ArxivDownloader(DocumentDownloader): - def __init__(self, download_dir, verbose=False): - super().__init__() - self._download_dir = download_dir - self._verbose = False - - def download(self, tarfile): - output_file = os.path.join(self._download_dir, tarfile) - s3path = os.path.join("s3://arxiv/src", tarfile) - if os.path.exists(output_file): - print(f"tar file: {output_file} exists. Not downloading") - else: - print(f"Downloading {s3path} and writing to {output_file}") - cmd = ["s5cmd", "--request-payer=requester", "cp", s3path, output_file] - if self._verbose: - stdout, stderr = None, None - else: - stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL - p = subprocess.run( - cmd, - stdout=stdout, - stderr=stderr, - ) - if p.returncode != 0: - print(f"Failed to download {s3path} to {output_file}") - - return output_file + def __init__(self, download_dir, verbose=False): + super().__init__() + self._download_dir = download_dir + self._verbose = False + + def download(self, tarfile): + output_file = os.path.join(self._download_dir, tarfile) + s3path = os.path.join("s3://arxiv/src", tarfile) + if os.path.exists(output_file): + print(f"tar file: {output_file} exists. 
Not downloading") + else: + print(f"Downloading {s3path} and writing to {output_file}") + cmd = ["s5cmd", "--request-payer=requester", "cp", s3path, output_file] + if self._verbose: + stdout, stderr = None, None + else: + stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL + p = subprocess.run( + cmd, + stdout=stdout, + stderr=stderr, + ) + if p.returncode != 0: + print(f"Failed to download {s3path} to {output_file}") + + return output_file class ArxivIterator(DocumentIterator): - def __init__(self, log_frequency=1000): - super().__init__() - self._log_frequency = log_frequency - self._counter = 0 - - def iterate(self, file_path): - self._counter = 0 - download_dir = os.path.split(file_path)[0] - bname = os.path.split(file_path)[-1] - with tempfile.TemporaryDirectory(dir=download_dir) as tmpdir: - with tarfile.open(file_path) as tf: - tf.extractall(members=tf.getmembers(), path=tmpdir) - for i, item in enumerate(get_all_files_paths_under(tmpdir)): - if self._counter > 0 and self._counter % self._log_frequency == 0: - print(f"Extracted {self._counter} papers from {file_path}") - self._counter += 1 - - tex_files = self._tex_proj_loader(item) - arxiv_id = os.path.splitext(os.path.split(item)[-1])[0] - - # get the arxiv id in the correct format - try: - clean_arxiv_id = self._format_arxiv_id(arxiv_id) - except Exception as e: - print(f"[WARNING] failed to format arxiv id " - f"{arxiv_id}; exception={e}") - clean_arxiv_id = arxiv_id - - if tex_files is None: - continue - - yield {'id': clean_arxiv_id, 'source_id': f'{bname}'}, tex_files - - def _tex_proj_loader(self, file_or_dir_path): - r""" function to load the tex files from a tar file or a gzip file. The - function will return a tuple containing a list of tex files and the - timestamp of the project. - - @param file_or_dir_path: path to the tar file or the gzip file - - @return: tuple containing a list of tex files and the timestamp of the - project - """ - files_and_content = [] - - try: - # if it is a directory, open it as a tarfile - with tarfile.open(file_or_dir_path) as sub_tf: - for member in sub_tf.getmembers(): - if member.name.endswith(".tex"): - - file_content = sub_tf.extractfile(member).read() + def __init__(self, log_frequency=1000): + super().__init__() + self._log_frequency = log_frequency + self._counter = 0 + + def iterate(self, file_path): + self._counter = 0 + download_dir = os.path.split(file_path)[0] + bname = os.path.split(file_path)[-1] + with tempfile.TemporaryDirectory(dir=download_dir) as tmpdir: + with tarfile.open(file_path) as tf: + tf.extractall(members=tf.getmembers(), path=tmpdir) + for i, item in enumerate(get_all_files_paths_under(tmpdir)): + if self._counter > 0 and self._counter % self._log_frequency == 0: + print(f"Extracted {self._counter} papers from {file_path}") + self._counter += 1 + + tex_files = self._tex_proj_loader(item) + arxiv_id = os.path.splitext(os.path.split(item)[-1])[0] + + # get the arxiv id in the correct format + try: + clean_arxiv_id = self._format_arxiv_id(arxiv_id) + except Exception as e: + print( + f"[WARNING] failed to format arxiv id " + f"{arxiv_id}; exception={e}" + ) + clean_arxiv_id = arxiv_id + + if tex_files is None: + continue + + yield {"id": clean_arxiv_id, "source_id": f"{bname}"}, tex_files + + def _tex_proj_loader(self, file_or_dir_path): + r"""function to load the tex files from a tar file or a gzip file. The + function will return a tuple containing a list of tex files and the + timestamp of the project. 
+ + @param file_or_dir_path: path to the tar file or the gzip file + + @return: tuple containing a list of tex files and the timestamp of the + project + """ + files_and_content = [] + + try: + # if it is a directory, open it as a tarfile + with tarfile.open(file_or_dir_path) as sub_tf: + for member in sub_tf.getmembers(): + if member.name.endswith(".tex"): + + file_content = sub_tf.extractfile(member).read() + + try: + file_content = file_content.decode("utf-8") + except UnicodeDecodeError: + # self._logger.info(f"UnicodeDecodeError: {file_or_dir_path}") + return None + + files_and_content.append(file_content) + + except tarfile.ReadError: + # otherwise we try opening it as a gzip file + try: + with gzip.open(file_or_dir_path, "rb") as gz: + file_content = gz.read() + except Exception: + # all fails, we skip this file + # self._logger.info(f"[ERROR] {e}: {file_or_dir_path}") + return None try: - file_content = file_content.decode("utf-8") + file_content = file_content.decode("utf-8") except UnicodeDecodeError: - # self._logger.info(f"UnicodeDecodeError: {file_or_dir_path}") - return None + # self._logger.info(f"UnicodeDecodeError: {file_or_dir_path}") + return None files_and_content.append(file_content) - except tarfile.ReadError: - # otherwise we try opening it as a gzip file - try: - with gzip.open(file_or_dir_path, "rb") as gz: - file_content = gz.read() - except Exception: - # all fails, we skip this file - # self._logger.info(f"[ERROR] {e}: {file_or_dir_path}") - return None + except Exception as e: + print(f"[ERROR] {e}: {file_or_dir_path}") + return None - try: - file_content = file_content.decode("utf-8") - except UnicodeDecodeError: - # self._logger.info(f"UnicodeDecodeError: {file_or_dir_path}") - return None + return files_and_content - files_and_content.append(file_content) + def _format_arxiv_id(self, arxiv_id): + r"""this function brings the raw arxiv-id into a format compliant with the + specification from arxiv. This is used to create the url to the arxiv + abstract page. - except Exception as e: - print(f"[ERROR] {e}: {file_or_dir_path}") - return None + - Format prior to March 2007: + /YYMMNNN where N is a 3-digit number + - Format after March 2007: /YYMM.NNNNN where N is a + 5 (or 6)-digit number - return files_and_content + References: https://info.arxiv.org/help/arxiv_identifier.html - def _format_arxiv_id(self, arxiv_id): - r""" this function brings the raw arxiv-id into a format compliant with the - specification from arxiv. This is used to create the url to the arxiv - abstract page. 
+ @param arxiv_id: raw arxiv id which can be in one of the + following formats: + - + - - - Format prior to March 2007: - /YYMMNNN where N is a 3-digit number - - Format after March 2007: /YYMM.NNNNN where N is a - 5 (or 6)-digit number + @return: formatted arxiv id + """ + match = re.search(r"^([a-zA-Z-]*)([\d\.]+)$", arxiv_id) - References: https://info.arxiv.org/help/arxiv_identifier.html + if match is None: + raise ValueError(f"Invalid arxiv id: {arxiv_id}") - @param arxiv_id: raw arxiv id which can be in one of the - following formats: - - - - + if match.group(1) == "": + return match.group(2) - @return: formatted arxiv id - """ - match = re.search(r'^([a-zA-Z-]*)([\d\.]+)$', arxiv_id) - - if match is None: - raise ValueError(f"Invalid arxiv id: {arxiv_id}") - - if match.group(1) == "": - return match.group(2) - - return f"{match.group(1)}/{match.group(2)}" + return f"{match.group(1)}/{match.group(2)}" class ArxivExtractor(DocumentExtractor): - def __init__(self): - super().__init__() - - def extract(self, content): - r""" function takes a list of tex files and returns a cleaned version of - the tex project. The cleaned version is a concatenation of the tex files - with the following modifications: - - - if multiple latex files, then concatenate them - - remove all comments (i.e. all lines starting with %) - - remove everything before the first \section header - - remove everything after the first occurrence of either \appendix or - \bibliography - - inline-expand definitions and macros - - @param tex_files: list of file_content strings - - @return: cleaned tex project as a string, empty string - if no tex files are provided - """ - if len(content) == 0: - return None - - # build dictionaries that contain the definitions of all macros in all tex - # files. This is later used to expand all macros used in the text with - # their definitions, so that consistency among different authors is - # ensured. - - non_arg_macros = {} - for file_content in content: - non_arg_macros.update(self._build_non_arg_macros_dict(file_content)) - - # TODO: macros that take arguments are not supported yet - arg_macros = {} - - # join multiple latex files with a newline character - try: - cleaned_latex_file_str = "\n".join( - self._clean_tex_file( - file_content=file_content, - arg_macros=arg_macros, - non_arg_macros=non_arg_macros, - ) for file_content in content) - except Exception: - return {}, None - - # Don't return meta - if cleaned_latex_file_str is not None: - if len(cleaned_latex_file_str) > 0: - return {}, cleaned_latex_file_str - - def _clean_tex_file(self, file_content, arg_macros, non_arg_macros): - r""" function takes a tex file as input and returns a cleaned version. The - cleaned version is a concatenation of the tex files with the - following modifications: - - - remove all comments (i.e. all lines starting with %) - - remove everything before the first section-like header - - remove everything after the first occurrence of either \appendix or - \bibliography - - inline-expand definitions and macros - - @param file_content: the content of the tex file as a string. - - @return: cleaned tex file as a string - """ - # find the first occurence of a \section-like header and replace everything - # before it with an empty string. 
This matches the following pattern: - # \[optional-args]{name} - pattern = r"^(.*?)(" - pattern += r"\\\bchapter\b\*?(?:\[(.*?)\])?\{(.*?)\}|" - pattern += r"\\\bpart\b\*?(?:\[(.*?)\])?\{(.*?)\}|" - pattern += r"\\\bsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" - pattern += r"\\\bsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" - pattern += r"\\\bsubsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" - pattern += r"\\\bparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}" - pattern += r"\\\bsubparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}" - pattern += r")" - - # if no section like header is found, then we return an empty string - if not re.search(pattern, file_content, flags=re.DOTALL): - return "" - - # replace everything with the second group of the match (i.e. everything - # after and including the section header) - file_content = re.sub( - pattern=pattern, - repl=r"\2", - string=file_content, - flags=re.DOTALL # make sure that the dot matches also newlines - ) - - # remove all line comments - file_content = re.sub( - pattern=r"(?m)^%.*\n?", - repl=r"", - string=file_content, - flags=re.MULTILINE, - ) - - # remove all in comments within a line - file_content = re.sub( - # pattern matches a "%" that is not preceded by a backslash (=comment) - pattern=r"[^\\]%.+$", - repl=r"", - string=file_content, - flags=re.MULTILINE) - - # find the first occurence of either \appendix or \bibliography and - # replace everything after it with an empty string - pattern = r"(" - pattern += r"\\appendix|" - pattern += r"\\begin\{references\}|" - pattern += r"\\begin\{REFERENCES\}|" - pattern += r"\\begin\{thebibliography\}|" - pattern += r"\\bibliography\{.*\}" - pattern += r").*$" - - file_content = re.sub( - pattern=pattern, - repl=r'', - string=file_content, - flags=re.DOTALL # make sure that the dot matches also newlines - ) - - # inline-expand all non-arg macros - for macro_name, macro_value in non_arg_macros.items(): - file_content = re.sub( - # make pattern grouped to make sure that the macro is not part - # of a longer alphanumeric word - pattern=r"(" + macro_name + r")" + r"([^a-zA-Z0-9])", - # replace the macro with its value and add back the character that - # was matched after the macro - repl=macro_value + r"\2", - string=file_content, - ) - - # inline-expand all macros that use args - # TODO: inline-expand macros with args - for macro_name, macro_value in arg_macros.items(): - pass - - return file_content - - def _build_non_arg_macros_dict(self, file_content): - r""" function takes the content of a tex file and returns a dictionary - that contains the definitions of all macros that do not use arguments. - The dictionary is of the form {macro_name: macro_value}. - - @param file_content: the content of the tex file as a string. - - @return: dict - """ - # regex for extracting \newcommand macros without arguments - non_arg_nc_reg = re.compile( - # this regex matches the following: - # \newcommand{\macro_name}{macro_value} - # \newcommand*{\macro_name}{macro_value} - # where macro_name is only allowed to contain letters and numbers; - # macro_value can contain any character. - pattern=r'\\\bnewcommand\b\*?\{(\\[a-zA-Z0-9]+?)\}\{(.*?)\}$', - flags=re.MULTILINE, + def __init__(self): + super().__init__() + + def extract(self, content): + r"""function takes a list of tex files and returns a cleaned version of + the tex project. The cleaned version is a concatenation of the tex files + with the following modifications: + + - if multiple latex files, then concatenate them + - remove all comments (i.e. 
all lines starting with %) + - remove everything before the first \section header + - remove everything after the first occurrence of either \appendix or + \bibliography + - inline-expand definitions and macros + + @param tex_files: list of file_content strings + + @return: cleaned tex project as a string, empty string + if no tex files are provided + """ + if len(content) == 0: + return None + + # build dictionaries that contain the definitions of all macros in all tex + # files. This is later used to expand all macros used in the text with + # their definitions, so that consistency among different authors is + # ensured. + + non_arg_macros = {} + for file_content in content: + non_arg_macros.update(self._build_non_arg_macros_dict(file_content)) + + # TODO: macros that take arguments are not supported yet + arg_macros = {} + + # join multiple latex files with a newline character + try: + cleaned_latex_file_str = "\n".join( + self._clean_tex_file( + file_content=file_content, + arg_macros=arg_macros, + non_arg_macros=non_arg_macros, + ) + for file_content in content + ) + except Exception: + return {}, None + + # Don't return meta + if cleaned_latex_file_str is not None: + if len(cleaned_latex_file_str) > 0: + return {}, cleaned_latex_file_str + + def _clean_tex_file(self, file_content, arg_macros, non_arg_macros): + r"""function takes a tex file as input and returns a cleaned version. The + cleaned version is a concatenation of the tex files with the + following modifications: + + - remove all comments (i.e. all lines starting with %) + - remove everything before the first section-like header + - remove everything after the first occurrence of either \appendix or + \bibliography + - inline-expand definitions and macros + + @param file_content: the content of the tex file as a string. + + @return: cleaned tex file as a string + """ + # find the first occurence of a \section-like header and replace everything + # before it with an empty string. This matches the following pattern: + # \[optional-args]{name} + pattern = r"^(.*?)(" + pattern += r"\\\bchapter\b\*?(?:\[(.*?)\])?\{(.*?)\}|" + pattern += r"\\\bpart\b\*?(?:\[(.*?)\])?\{(.*?)\}|" + pattern += r"\\\bsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" + pattern += r"\\\bsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" + pattern += r"\\\bsubsubsection\b\*?(?:\[(.*?)\])?\{(.*?)\}|" + pattern += r"\\\bparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}" + pattern += r"\\\bsubparagraph\b\*?(?:\[(.*?)\])?\{(.*?)\}" + pattern += r")" + + # if no section like header is found, then we return an empty string + if not re.search(pattern, file_content, flags=re.DOTALL): + return "" + + # replace everything with the second group of the match (i.e. 
everything + # after and including the section header) + file_content = re.sub( + pattern=pattern, + repl=r"\2", + string=file_content, + flags=re.DOTALL, # make sure that the dot matches also newlines + ) + + # remove all line comments + file_content = re.sub( + pattern=r"(?m)^%.*\n?", + repl=r"", + string=file_content, + flags=re.MULTILINE, + ) + + # remove all in comments within a line + file_content = re.sub( + # pattern matches a "%" that is not preceded by a backslash (=comment) + pattern=r"[^\\]%.+$", + repl=r"", + string=file_content, + flags=re.MULTILINE, + ) + + # find the first occurence of either \appendix or \bibliography and + # replace everything after it with an empty string + pattern = r"(" + pattern += r"\\appendix|" + pattern += r"\\begin\{references\}|" + pattern += r"\\begin\{REFERENCES\}|" + pattern += r"\\begin\{thebibliography\}|" + pattern += r"\\bibliography\{.*\}" + pattern += r").*$" + + file_content = re.sub( + pattern=pattern, + repl=r"", + string=file_content, + flags=re.DOTALL, # make sure that the dot matches also newlines + ) + + # inline-expand all non-arg macros + for macro_name, macro_value in non_arg_macros.items(): + file_content = re.sub( + # make pattern grouped to make sure that the macro is not part + # of a longer alphanumeric word + pattern=r"(" + macro_name + r")" + r"([^a-zA-Z0-9])", + # replace the macro with its value and add back the character that + # was matched after the macro + repl=macro_value + r"\2", + string=file_content, + ) + + # inline-expand all macros that use args + # TODO: inline-expand macros with args + for macro_name, macro_value in arg_macros.items(): + pass + + return file_content + + def _build_non_arg_macros_dict(self, file_content): + r"""function takes the content of a tex file and returns a dictionary + that contains the definitions of all macros that do not use arguments. + The dictionary is of the form {macro_name: macro_value}. + + @param file_content: the content of the tex file as a string. + + @return: dict + """ + # regex for extracting \newcommand macros without arguments + non_arg_nc_reg = re.compile( + # this regex matches the following: + # \newcommand{\macro_name}{macro_value} + # \newcommand*{\macro_name}{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. + pattern=r"\\\bnewcommand\b\*?\{(\\[a-zA-Z0-9]+?)\}\{(.*?)\}$", + flags=re.MULTILINE, + ) + + # regex for extracting \def macros without arguments + non_arg_def_reg = re.compile( + # this regex matches the following: + # \def\macro_name{macro_value} + # where macro_name is only allowed to contain letters and numbers; + # macro_value can contain any character. 
+ pattern=r"\\def\s*(\\[a-zA-Z0-9]+?)\s*\{(.*?)\}$", + flags=re.MULTILINE, + ) + + # Extract all user-defined LaTeX macros from the preamble + macros = {} + for reg in [non_arg_nc_reg, non_arg_def_reg]: + for match in reg.finditer(file_content): + # convert the macro name and value to a raw string that can be + # used in re.sub + macro_name = match.group(1).encode("unicode-escape").decode("utf-8") + macro_val = match.group(2).encode("unicode-escape").decode("utf-8") + + macros[macro_name] = macro_val + + return macros + + +def download_arxiv( + output_path: str, + output_type: str = "jsonl", + raw_download_dir=None, + keep_raw_download=False, + force_download=False, + url_limit=None, +) -> DocumentDataset: + """ + Downloads Arxiv tar files and extracts them + + Args: + output_path: The path to the root directory of the files + output_type: The file type to save the data as. + raw_download_dir: Path to store the raw download files for intermediate processing. + If None, they are stored in a folder named "downloads" under output_path. + keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + url_limit: The maximum number of raw files to download from the snapshot. If None, all + files from the range of snapshots are downloaded. + """ + arxiv_urls = get_arxiv_urls() + if url_limit: + arxiv_urls = arxiv_urls[:url_limit] + output_paths = list( + map(lambda url: os.path.join(output_path, f"{url}.{output_type}"), arxiv_urls) ) - # regex for extracting \def macros without arguments - non_arg_def_reg = re.compile( - # this regex matches the following: - # \def\macro_name{macro_value} - # where macro_name is only allowed to contain letters and numbers; - # macro_value can contain any character. - pattern=r'\\def\s*(\\[a-zA-Z0-9]+?)\s*\{(.*?)\}$', - flags=re.MULTILINE, + if not raw_download_dir: + raw_download_dir = os.path.join(output_path, "downloads") + expand_outdir_and_mkdir(raw_download_dir) + downloader = ArxivDownloader(raw_download_dir) + iterator = ArxivIterator() + extractor = ArxivExtractor() + + output_format = { + "text": str, + "id": str, + "source_id": str, + "filename": str, + } + dataset = download_and_extract( + arxiv_urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, ) - # Extract all user-defined LaTeX macros from the preamble - macros = {} - for reg in [non_arg_nc_reg, non_arg_def_reg]: - for match in reg.finditer(file_content): - # convert the macro name and value to a raw string that can be - # used in re.sub - macro_name = match \ - .group(1).encode("unicode-escape").decode("utf-8") - macro_val = match \ - .group(2).encode("unicode-escape").decode("utf-8") - - macros[macro_name] = macro_val - - return macros - -def download_arxiv(output_path: str, output_type: str="jsonl", raw_download_dir=None, keep_raw_download=False, force_download=False, url_limit=None) -> DocumentDataset: - """ - Downloads Arxiv tar files and extracts them - - Args: - output_path: The path to the root directory of the files - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. 
- keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. - """ - arxiv_urls = get_arxiv_urls() - if url_limit: - arxiv_urls = arxiv_urls[:url_limit] - output_paths = list(map(lambda url: os.path.join(output_path, f"{url}.{output_type}"), arxiv_urls)) - - if not raw_download_dir: - raw_download_dir = os.path.join(output_path, "downloads") - expand_outdir_and_mkdir(raw_download_dir) - downloader = ArxivDownloader(raw_download_dir) - iterator = ArxivIterator() - extractor = ArxivExtractor() - - output_format = { - "text": str, - "id": str, - "source_id": str, - "filename": str, - } - dataset = download_and_extract(arxiv_urls, output_paths, downloader, iterator, extractor, output_format, output_type=output_type, keep_raw_download=keep_raw_download, force_download=force_download) - - return dataset \ No newline at end of file + return dataset diff --git a/nemo_curator/download/commoncrawl.py b/nemo_curator/download/commoncrawl.py index 8051db20..15736ee9 100644 --- a/nemo_curator/download/commoncrawl.py +++ b/nemo_curator/download/commoncrawl.py @@ -15,54 +15,57 @@ import os import subprocess -import pycld2 as cld2 +import unicodedata +from urllib.parse import urlparse + import justext import lxml -import unicodedata +import pycld2 as cld2 from charset_normalizer import detect -from urllib.parse import urlparse from warcio.archiveiterator import ArchiveIterator +from nemo_curator.datasets import DocumentDataset from nemo_curator.download.doc_builder import ( DocumentDownloader, - DocumentIterator, DocumentExtractor, + DocumentIterator, download_and_extract, ) from nemo_curator.utils.download_utils import get_common_crawl_urls -from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.file_utils import expand_outdir_and_mkdir def decode_html(html_bytes): - # Convert from bytes to text using utf-8 encoding - try: - return html_bytes.decode('utf-8') - except UnicodeDecodeError: - # If utf-8 fails, try to find a different encoding - return try_decode_with_detected_encoding(html_bytes) + # Convert from bytes to text using utf-8 encoding + try: + return html_bytes.decode("utf-8") + except UnicodeDecodeError: + # If utf-8 fails, try to find a different encoding + return try_decode_with_detected_encoding(html_bytes) def try_decode_with_detected_encoding(html_bytes): - detected_encoding = detect(html_bytes)['encoding'] - bad_detection = not detected_encoding or detected_encoding == 'utf-8' - if bad_detection: - return None - try: - return html_bytes.decode(detected_encoding) - except: - return None + detected_encoding = detect(html_bytes)["encoding"] + bad_detection = not detected_encoding or detected_encoding == "utf-8" + if bad_detection: + return None + try: + return html_bytes.decode(detected_encoding) + except: + return None def lang_detect(decoded_html): - try: - details = cld2.detect(decoded_html)[2] - except Exception: - # Remove control characters - cleaned_html = ''.join(i for i in decoded_html if unicodedata.category(i)[0] != 'C') - details = cld2.detect(cleaned_html)[2] + try: + details = cld2.detect(decoded_html)[2] + except Exception: + # Remove control characters + cleaned_html = "".join( + i for i in decoded_html if unicodedata.category(i)[0] != "C" + ) + details = 
cld2.detect(cleaned_html)[2] - return details[0][0].upper() + return details[0][0].upper() def extract_text( @@ -77,244 +80,275 @@ def extract_text( no_headings=False, logger=None, ): - # Segment the HTML into paragraphs - try: - # Form the DOM tree - dom = justext.core.html_to_dom(html) - cleaned_dom = justext.core.preprocessor(dom) - # Get the paragraphs from the DOM - handler = justext.core.ParagraphMaker() - lxml.sax.saxify(cleaned_dom, handler) - except (lxml.etree.ParserError, ValueError, Exception): - # Return nothing when we cannot segment the document - if logger is not None: - logger.info("Could not segment paragaphs in the document") - return - paragraphs = handler.paragraphs - - # Context free classification - justext.core.classify_paragraphs( - paragraphs, - stop_words, - length_low, - length_high, - stopwords_low, - stopwords_high, - max_link_density, - no_headings, - ) - - # Copy the context free class to the class_style - # This handles the headings as described in the - # documentation - for paragraph in paragraphs: - paragraph.class_type = paragraph.cf_class - - # Context sensitive classification - justext.core.revise_paragraph_classification( - paragraphs, - max_heading_distance, - ) - - return [p.text for p in paragraphs if not p.is_boilerplate] + # Segment the HTML into paragraphs + try: + # Form the DOM tree + dom = justext.core.html_to_dom(html) + cleaned_dom = justext.core.preprocessor(dom) + # Get the paragraphs from the DOM + handler = justext.core.ParagraphMaker() + lxml.sax.saxify(cleaned_dom, handler) + except (lxml.etree.ParserError, ValueError, Exception): + # Return nothing when we cannot segment the document + if logger is not None: + logger.info("Could not segment paragaphs in the document") + return + paragraphs = handler.paragraphs + + # Context free classification + justext.core.classify_paragraphs( + paragraphs, + stop_words, + length_low, + length_high, + stopwords_low, + stopwords_high, + max_link_density, + no_headings, + ) + + # Copy the context free class to the class_style + # This handles the headings as described in the + # documentation + for paragraph in paragraphs: + paragraph.class_type = paragraph.cf_class + + # Context sensitive classification + justext.core.revise_paragraph_classification( + paragraphs, + max_heading_distance, + ) + + return [p.text for p in paragraphs if not p.is_boilerplate] def get_stop_list_dict(languages=[]): - # Name mapping for language names from CLD2 (values) - # and jusText (keys) - lang_map = { - 'Haitian': 'HAITIAN_CREOLE', - 'Norwegian_Bokmal': 'NORWEGIAN', - 'Norwegian_Nynorsk': 'NORWEGIAN_N', - 'Waray_Waray': 'WARAY_PHILIPPINES', - } - if len(languages) == 0: - languages = justext.get_stoplists() - # Remove latin as it yields a lot of low quality documents - languages_no_latin = list(languages) - languages_no_latin.remove('Latin') - languages = frozenset(languages_no_latin) - - stop_list_dict = {} - for language in languages: - if language in lang_map: - lang_key = lang_map[language] - else: - lang_key = language.upper() - stop_list_dict[lang_key] = justext.get_stoplist(language) - - return stop_list_dict + # Name mapping for language names from CLD2 (values) + # and jusText (keys) + lang_map = { + "Haitian": "HAITIAN_CREOLE", + "Norwegian_Bokmal": "NORWEGIAN", + "Norwegian_Nynorsk": "NORWEGIAN_N", + "Waray_Waray": "WARAY_PHILIPPINES", + } + if len(languages) == 0: + languages = justext.get_stoplists() + # Remove latin as it yields a lot of low quality documents + languages_no_latin = list(languages) + 
languages_no_latin.remove("Latin") + languages = frozenset(languages_no_latin) + + stop_list_dict = {} + for language in languages: + if language in lang_map: + lang_key = lang_map[language] + else: + lang_key = language.upper() + stop_list_dict[lang_key] = justext.get_stoplist(language) + + return stop_list_dict def get_all_stop_words(): - stop_words = set() - for language in justext.get_stoplists(): - stop_words.update(justext.get_stoplist(language)) + stop_words = set() + for language in justext.get_stoplists(): + stop_words.update(justext.get_stoplist(language)) - return frozenset(stop_words) + return frozenset(stop_words) class CommonCrawlWARCDownloader(DocumentDownloader): - """ - Downloads WARC files from the Common Crawl - """ - - def __init__(self, download_dir, aws=False, verbose=False): """ - Creates a downloader - - Args: - download_dir: Path to store raw compressed WARC files - aws: If True, uses the s5cmd command to download from the Common Crawl's S3 bucket. - If False, uses wget. - verbose: If True, logs stdout and stderr of the download command (s5cmd/wget) + Downloads WARC files from the Common Crawl """ - super().__init__() - self._download_dir = download_dir - self._aws = aws - self._verbose = verbose - - def download(self, url): - # Download each URL to the directory - urlpath = urlparse(url).path[1:] - output_name = urlpath.replace('/', '-') - output_file = os.path.join(self._download_dir, output_name) - if os.path.exists(output_file): - print(f"WARC file: {output_file} exists. Not downloading") - else: - print(f"Downloading {url} and writing to {output_file}") - # Download with either wget or s5cmd (aws) - if self._aws: - s3path = os.path.join("s3://commoncrawl/", urlpath) - cmd = ["s5cmd", "cp", s3path, output_file] - else: - cmd = ["wget", url, "-O", output_file] - if self._verbose: - stdout, stderr = None, None - else: - stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL - p = subprocess.run( - cmd, - stdout=stdout, - stderr=stderr, - ) - if p.returncode != 0: - print(f"Failed to download {url} to {output_file}") - - return output_file + + def __init__(self, download_dir, aws=False, verbose=False): + """ + Creates a downloader + + Args: + download_dir: Path to store raw compressed WARC files + aws: If True, uses the s5cmd command to download from the Common Crawl's S3 bucket. + If False, uses wget. + verbose: If True, logs stdout and stderr of the download command (s5cmd/wget) + """ + super().__init__() + self._download_dir = download_dir + self._aws = aws + self._verbose = verbose + + def download(self, url): + # Download each URL to the directory + urlpath = urlparse(url).path[1:] + output_name = urlpath.replace("/", "-") + output_file = os.path.join(self._download_dir, output_name) + if os.path.exists(output_file): + print(f"WARC file: {output_file} exists. 
Not downloading") + else: + print(f"Downloading {url} and writing to {output_file}") + # Download with either wget or s5cmd (aws) + if self._aws: + s3path = os.path.join("s3://commoncrawl/", urlpath) + cmd = ["s5cmd", "cp", s3path, output_file] + else: + cmd = ["wget", url, "-O", output_file] + if self._verbose: + stdout, stderr = None, None + else: + stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL + p = subprocess.run( + cmd, + stdout=stdout, + stderr=stderr, + ) + if p.returncode != 0: + print(f"Failed to download {url} to {output_file}") + + return output_file class CommonCrawlWARCDownloaderExtractOnly(DocumentDownloader): - """ - A 'dummy' downloader that simply puts pre-downloaded - files on the queue - """ + """ + A 'dummy' downloader that simply puts pre-downloaded + files on the queue + """ - def __init__(self, aws=False, verbose=False): - super().__init__() + def __init__(self, aws=False, verbose=False): + super().__init__() - def download(self, url): - print(f"Putting WARC file {url} on the queue for extraction") - return url + def download(self, url): + print(f"Putting WARC file {url} on the queue for extraction") + return url class CommonCrawlWARCIterator(DocumentIterator): - def __init__(self, log_frequency=1000): - super().__init__() - self._counter = 0 - self._log_frequency = log_frequency - - def iterate(self, file_path): - # Loop over all records in the current WARC - self._counter = 0 - bname = os.path.split(file_path)[-1] - with open(file_path, 'rb') as file_pointer: - ai = ArchiveIterator(file_pointer, arc2warc=True) - for k, rec in enumerate(ai): - # Get the response from the crawl - if rec.rec_type == 'response': - if self._counter > 0 and self._counter % self._log_frequency == 0: - print(f"Extracted {self._counter} records in WARC") - self._counter += 1 - content = rec.content_stream().read() - warc_id = rec.rec_headers.get_header('WARC-Record-ID')[10:-1] - url = rec.rec_headers.get_header('WARC-Target-URI') - meta = { - 'url': url, - 'warc_id': warc_id, - 'source_id': f'{bname}', - } - yield meta, content + def __init__(self, log_frequency=1000): + super().__init__() + self._counter = 0 + self._log_frequency = log_frequency + + def iterate(self, file_path): + # Loop over all records in the current WARC + self._counter = 0 + bname = os.path.split(file_path)[-1] + with open(file_path, "rb") as file_pointer: + ai = ArchiveIterator(file_pointer, arc2warc=True) + for k, rec in enumerate(ai): + # Get the response from the crawl + if rec.rec_type == "response": + if self._counter > 0 and self._counter % self._log_frequency == 0: + print(f"Extracted {self._counter} records in WARC") + self._counter += 1 + content = rec.content_stream().read() + warc_id = rec.rec_headers.get_header("WARC-Record-ID")[10:-1] + url = rec.rec_headers.get_header("WARC-Target-URI") + meta = { + "url": url, + "warc_id": warc_id, + "source_id": f"{bname}", + } + yield meta, content class CommonCrawlWARCExtractor(DocumentExtractor): - def __init__(self): - self._stop_lists = get_stop_list_dict() - super().__init__() - - def extract(self, content): - html = decode_html(content) - if html is not None: - # Language detection and HTML extraction - lang = lang_detect(html) - text = None - if lang in self._stop_lists: - text = extract_text(html, self._stop_lists[lang]) - if text is not None: - if len(text) > 0: - text = '\n\n'.join(text) - meta = {'language': lang} - return meta, text - else: - return None, None - -def download_common_crawl(output_path: str, start_snapshot: str, end_snapshot: str, 
output_type: str="jsonl", news=False, aws=False, raw_download_dir=None, keep_raw_download=False, force_download=False, url_limit=None) -> DocumentDataset: - """ - Downloads Common Crawl WARC snapshots and extracts them using jusText - - Args: - output_path: The path to the root directory of the files - start_snapshot: The first common crawl snapshot to include. Snapshots must be - specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, - (specified with news=True flag) this changes to Year-Month (YYYY-MM). - end_snapshot: The last common crawl snapshot to include. Must be chronologically - after the starting snapshot. - output_type: The file type to save the data as. - news: If True, gets WARC URLs for the CC-NEWS dataset instead of the CC-MAIN datasets. - Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month). - aws: Whether to download from Common Crawl's S3 bucket. If True, uses s5cmd to download. - If False, uses wget. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. - """ - common_crawl_urls = get_common_crawl_urls(starting_snapshot=start_snapshot, ending_snapshot=end_snapshot, news=news) - if url_limit: - common_crawl_urls = common_crawl_urls[:url_limit] - output_paths = list(map(lambda url: os.path.join(output_path, url.split("/")[-1] + f".{output_type}"), common_crawl_urls)) - - if not raw_download_dir: - raw_download_dir = os.path.join(output_path, "downloads") - expand_outdir_and_mkdir(raw_download_dir) - downloader = CommonCrawlWARCDownloader(raw_download_dir, aws=aws) - iterator = CommonCrawlWARCIterator() - extractor = CommonCrawlWARCExtractor() - - output_format = { - "text": str, - "language": str, - "url": str, - "warc_id": str, - "source_id": str, - "filename": str, - } - dataset = download_and_extract(common_crawl_urls, output_paths, downloader, iterator, extractor, output_format, output_type=output_type, keep_raw_download=keep_raw_download, force_download=force_download) - - return dataset \ No newline at end of file + def __init__(self): + self._stop_lists = get_stop_list_dict() + super().__init__() + + def extract(self, content): + html = decode_html(content) + if html is not None: + # Language detection and HTML extraction + lang = lang_detect(html) + text = None + if lang in self._stop_lists: + text = extract_text(html, self._stop_lists[lang]) + if text is not None: + if len(text) > 0: + text = "\n\n".join(text) + meta = {"language": lang} + return meta, text + else: + return None, None + + +def download_common_crawl( + output_path: str, + start_snapshot: str, + end_snapshot: str, + output_type: str = "jsonl", + news=False, + aws=False, + raw_download_dir=None, + keep_raw_download=False, + force_download=False, + url_limit=None, +) -> DocumentDataset: + """ + Downloads Common Crawl WARC snapshots and extracts them using jusText + + Args: + output_path: The path to the root directory of the files + start_snapshot: The first common crawl snapshot to include. Snapshots must be + specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). 
For the CC-NEWS dataset, + (specified with news=True flag) this changes to Year-Month (YYYY-MM). + end_snapshot: The last common crawl snapshot to include. Must be chronologically + after the starting snapshot. + output_type: The file type to save the data as. + news: If True, gets WARC URLs for the CC-NEWS dataset instead of the CC-MAIN datasets. + Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month). + aws: Whether to download from Common Crawl's S3 bucket. If True, uses s5cmd to download. + If False, uses wget. + raw_download_dir: Path to store the raw download files for intermediate processing. + If None, they are stored in a folder named "downloads" under output_path. + keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + url_limit: The maximum number of raw files to download from the snapshot. If None, all + files from the range of snapshots are downloaded. + """ + common_crawl_urls = get_common_crawl_urls( + starting_snapshot=start_snapshot, ending_snapshot=end_snapshot, news=news + ) + if url_limit: + common_crawl_urls = common_crawl_urls[:url_limit] + output_paths = list( + map( + lambda url: os.path.join( + output_path, url.split("/")[-1] + f".{output_type}" + ), + common_crawl_urls, + ) + ) + + if not raw_download_dir: + raw_download_dir = os.path.join(output_path, "downloads") + expand_outdir_and_mkdir(raw_download_dir) + downloader = CommonCrawlWARCDownloader(raw_download_dir, aws=aws) + iterator = CommonCrawlWARCIterator() + extractor = CommonCrawlWARCExtractor() + + output_format = { + "text": str, + "language": str, + "url": str, + "warc_id": str, + "source_id": str, + "filename": str, + } + dataset = download_and_extract( + common_crawl_urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + ) + + return dataset diff --git a/nemo_curator/download/doc_builder.py b/nemo_curator/download/doc_builder.py index 03395e91..8bdf3e30 100644 --- a/nemo_curator/download/doc_builder.py +++ b/nemo_curator/download/doc_builder.py @@ -13,162 +13,195 @@ # limitations under the License. 
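The reformatted download_common_crawl above is the public Common Crawl entry point touched by this patch. A minimal usage sketch, assuming a Dask client has already been started (for example via get_client from nemo_curator.utils.distributed_utils) and using a placeholder output path; the snapshot strings follow the YYYY-WeekNumber format described in the docstring, and url_limit is set only to keep the sketch small:

# Illustrative sketch only: the output path is a placeholder and a running
# Dask client is assumed before the download is triggered.
from nemo_curator.download import download_common_crawl

common_crawl = download_common_crawl(
    output_path="/path/to/common_crawl_output",
    start_snapshot="2020-50",
    end_snapshot="2021-04",
    output_type="jsonl",
    url_limit=2,  # omit to process the full snapshot range
)

# The result is a DocumentDataset wrapping a Dask DataFrame with the
# text, language, url, warc_id, source_id and filename columns declared above.
print(common_crawl.df.columns)

download_arxiv above is wired the same way and likewise returns a DocumentDataset, with text, id, source_id and filename columns.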
import importlib +import os from abc import ABC, abstractmethod from typing import List, Tuple + import dask.dataframe as dd -from dask import delayed, compute import pandas as pd -import os +from dask import compute, delayed from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.distributed_utils import single_partition_write_with_filename, read_single_partition +from nemo_curator.utils.distributed_utils import ( + read_single_partition, + single_partition_write_with_filename, +) class DocumentDownloader(ABC): - """ Abstract class for downloading remote data to disk """ + """Abstract class for downloading remote data to disk""" - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @abstractmethod - def download(self, url): - pass + @abstractmethod + def download(self, url): + pass class DocumentIterator(ABC): - """ - Abstract iterator class for reading in raw records that have been - downloaded to disk - """ + """ + Abstract iterator class for reading in raw records that have been + downloaded to disk + """ - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @abstractmethod - def iterate(self, file_path): - pass + @abstractmethod + def iterate(self, file_path): + pass class DocumentExtractor(ABC): - """ Abstract class for extracting text from records read from disk """ + """Abstract class for extracting text from records read from disk""" - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - @abstractmethod - def extract(self, content): - pass + @abstractmethod + def extract(self, content): + pass def import_downloader(downloader_path): - module_path, downloader_name = downloader_path.rsplit(".", 1) - downloader_module = importlib.import_module(module_path) - downloader_class = getattr(downloader_module, downloader_name) - if not issubclass(downloader_class, DocumentDownloader): - raise ValueError(f"Input downloader {downloader_class.__name__} " - "must be derived from DocumentDownloader defined in " - "nemo_curator.download.docbuilder") - return downloader_class + module_path, downloader_name = downloader_path.rsplit(".", 1) + downloader_module = importlib.import_module(module_path) + downloader_class = getattr(downloader_module, downloader_name) + if not issubclass(downloader_class, DocumentDownloader): + raise ValueError( + f"Input downloader {downloader_class.__name__} " + "must be derived from DocumentDownloader defined in " + "nemo_curator.download.docbuilder" + ) + return downloader_class def import_iterator(iterator_path): - module_path, iterator_name = iterator_path.rsplit(".", 1) - iterator_module = importlib.import_module(module_path) - iterator_class = getattr(iterator_module, iterator_name) - if not issubclass(iterator_class, DocumentIterator): - raise ValueError(f"Input iterator {iterator_class.__name__} " - "must be derived from DocumentIterator " - "defined in nemo_curator.download.docbuilder") - return iterator_class + module_path, iterator_name = iterator_path.rsplit(".", 1) + iterator_module = importlib.import_module(module_path) + iterator_class = getattr(iterator_module, iterator_name) + if not issubclass(iterator_class, DocumentIterator): + raise ValueError( + f"Input iterator {iterator_class.__name__} " + "must be derived from DocumentIterator " + "defined in nemo_curator.download.docbuilder" + ) + return iterator_class def import_extractor(extractor_path): - module_path, extractor_name = extractor_path.rsplit(".", 1) - extractor_module = 
importlib.import_module(module_path) - extractor_class = getattr(extractor_module, extractor_name) - if not issubclass(extractor_class, DocumentExtractor): - raise ValueError(f"Input extractor {extractor_class.__name__} " - "must be derived from DocumentExtractor defined " - "in nemo_curator.download.docbuilder") - return extractor_class - -def _download_and_extract_single_partition(paths: List[Tuple[str, str]], downloader: DocumentDownloader, iterator: DocumentIterator, extractor: DocumentExtractor, output_type: str, keep_raw_download: bool, force_download: bool) -> pd.DataFrame: - url, output_path = paths - - if os.path.exists(output_path) and not force_download: - partition = read_single_partition([output_path], backend="pandas", filetype=output_type, add_filename=True) + module_path, extractor_name = extractor_path.rsplit(".", 1) + extractor_module = importlib.import_module(module_path) + extractor_class = getattr(extractor_module, extractor_name) + if not issubclass(extractor_class, DocumentExtractor): + raise ValueError( + f"Input extractor {extractor_class.__name__} " + "must be derived from DocumentExtractor defined " + "in nemo_curator.download.docbuilder" + ) + return extractor_class + + +def _download_and_extract_single_partition( + paths: List[Tuple[str, str]], + downloader: DocumentDownloader, + iterator: DocumentIterator, + extractor: DocumentExtractor, + output_type: str, + keep_raw_download: bool, + force_download: bool, +) -> pd.DataFrame: + url, output_path = paths + + if os.path.exists(output_path) and not force_download: + partition = read_single_partition( + [output_path], backend="pandas", filetype=output_type, add_filename=True + ) + return partition + + downloaded_file = downloader.download(url) + records = [] + # Iterate over all records in file + for item in iterator.iterate(downloaded_file): + record_meta, content = item + # Extract the text from the record + extracted = extractor.extract(content) + if extracted is not None: + text_meta, text = extracted + if text is not None: + line = { + "text": text, + **text_meta, + **record_meta, + } + records.append(line) + + partition = pd.DataFrame(records) + filename = os.path.basename(output_path) + output_dir = os.path.dirname(output_path) + partition["filename"] = filename + single_partition_write_with_filename(partition, output_dir, output_type=output_type) + if not keep_raw_download: + os.remove(downloaded_file) + return partition - downloaded_file = downloader.download(url) - records = [] - # Iterate over all records in file - for item in iterator.iterate(downloaded_file): - record_meta, content = item - # Extract the text from the record - extracted = extractor.extract(content) - if extracted is not None: - text_meta, text = extracted - if text is not None: - line = { - 'text': text, - **text_meta, - **record_meta, - } - records.append(line) - - partition = pd.DataFrame(records) - filename = os.path.basename(output_path) - output_dir = os.path.dirname(output_path) - partition["filename"] = filename - single_partition_write_with_filename(partition, output_dir, output_type=output_type) - if not keep_raw_download: - os.remove(downloaded_file) - - return partition - -def download_and_extract(urls: List[str], output_paths: List[str], downloader: DocumentDownloader, iterator: DocumentIterator, extractor: DocumentExtractor, output_format: dict, output_type: str="jsonl", keep_raw_download=False, force_download=False) -> DocumentDataset: - """ - Downloads and extracts a dataset into a format accepted by the NeMo Curator - 
- Args: - urls: A list of urls to download the dataset from - output_paths: A list of paths to save the final extracted output to. - The raw output of the downloader will be saved using the path given by downloader.download(url). - downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage - iterator: A DocumentIterator that handles iterating through the downloaded file's format - extractor: A DocumentExtractor that handles extracting the data from its raw format into text - output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction. - output_type: The file type to save the dataset as. - keep_raw_download: Whether to keep the pre-extracted download file. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - - Returns: - A DocumentDataset of the downloaded data - """ - if len(urls) != len(output_paths): - raise ValueError("Different number of urls and output_paths") - - output_format = dict(sorted(output_format.items())) - df = dd.from_map( - _download_and_extract_single_partition, - zip(urls, output_paths), - downloader=downloader, - iterator=iterator, - extractor=extractor, - output_type=output_type, - keep_raw_download=keep_raw_download, - force_download=force_download, - enforce_metadata=False, - meta=output_format, - ) - - return DocumentDataset(df) + +def download_and_extract( + urls: List[str], + output_paths: List[str], + downloader: DocumentDownloader, + iterator: DocumentIterator, + extractor: DocumentExtractor, + output_format: dict, + output_type: str = "jsonl", + keep_raw_download=False, + force_download=False, +) -> DocumentDataset: + """ + Downloads and extracts a dataset into a format accepted by the NeMo Curator + + Args: + urls: A list of urls to download the dataset from + output_paths: A list of paths to save the final extracted output to. + The raw output of the downloader will be saved using the path given by downloader.download(url). + downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage + iterator: A DocumentIterator that handles iterating through the downloaded file's format + extractor: A DocumentExtractor that handles extracting the data from its raw format into text + output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction. + output_type: The file type to save the dataset as. + keep_raw_download: Whether to keep the pre-extracted download file. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. 
+ + Returns: + A DocumentDataset of the downloaded data + """ + if len(urls) != len(output_paths): + raise ValueError("Different number of urls and output_paths") + + output_format = dict(sorted(output_format.items())) + df = dd.from_map( + _download_and_extract_single_partition, + zip(urls, output_paths), + downloader=downloader, + iterator=iterator, + extractor=extractor, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + enforce_metadata=False, + meta=output_format, + ) + + return DocumentDataset(df) + def batch_download(urls: List[str], downloader: DocumentDownloader) -> List[str]: - """ - Downloads all the urls using the downloader in parallel - """ - delayed_downloads = [delayed(downloader.download)(url) for url in urls] - - return compute(*delayed_downloads) \ No newline at end of file + """ + Downloads all the urls using the downloader in parallel + """ + delayed_downloads = [delayed(downloader.download)(url) for url in urls] + + return compute(*delayed_downloads) diff --git a/nemo_curator/download/wikipedia.py b/nemo_curator/download/wikipedia.py index 0df8059d..456142fb 100644 --- a/nemo_curator/download/wikipedia.py +++ b/nemo_curator/download/wikipedia.py @@ -12,25 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import bz2 +import codecs import os import re import subprocess -import bz2 -import codecs -import mwparserfromhell -from urllib.parse import urlparse, quote import xml.etree.cElementTree as etree +from urllib.parse import quote, urlparse + +import mwparserfromhell +from distributed import Lock +from nemo_curator.datasets import DocumentDataset from nemo_curator.download.doc_builder import ( DocumentDownloader, - DocumentIterator, DocumentExtractor, - download_and_extract + DocumentIterator, + download_and_extract, ) -from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.download_utils import get_wikipedia_urls from nemo_curator.utils.file_utils import expand_outdir_and_mkdir -from distributed import Lock # The majority of this code is taken from the HuggingFace # implementation of the Wikipedia dataset preparation: @@ -157,10 +158,7 @@ "koi": ["Медиа", "Файл", "Изображение"], "krc": ["Медиа", "Файл", "Изображение"], "ks": ["میڈیا", "فَیِل"], - "ksh": [ - "Beld", "Meedije", "Medie", "Belld", "Medium", "Datei", "Meedijum", - "Bild" - ], + "ksh": ["Beld", "Meedije", "Medie", "Belld", "Medium", "Datei", "Meedijum", "Bild"], "ku": ["میدیا", "پەڕگە", "Medya", "Wêne"], "kv": ["Медиа", "Файл", "Изображение"], "kw": ["Restren"], @@ -248,8 +246,14 @@ "sl": ["Slika", "Datoteka"], "sq": ["Figura", "Skeda"], "sr": [ - "Датотека", "Medij", "Slika", "Медија", "Datoteka", "Медиј", "Medija", - "Слика" + "Датотека", + "Medij", + "Slika", + "Медија", + "Datoteka", + "Медиј", + "Medija", + "Слика", ], "srn": ["Afbeelding", "Gefre"], "stq": ["Bielde", "Bild"], @@ -294,7 +298,19 @@ "zh": ["媒体文件", "F", "文件", "媒體", "档案", "图像", "圖像", "媒体", "檔案"], "zh-classical": ["文件", "媒體", "圖像", "檔案"], "zh-min-nan": ["tóng-àn", "文件", "媒體", "Mûi-thé", "圖像", "檔案"], - "zh-yue": ["檔", "档", "文件", "图", "媒體", "圖", "档案", "图像", "圖像", "媒体", "檔案"], + "zh-yue": [ + "檔", + "档", + "文件", + "图", + "媒體", + "圖", + "档案", + "图像", + "圖像", + "媒体", + "檔案", + ], } CAT_ALIASES = { @@ -420,8 +436,13 @@ "krc": ["Категория"], "ks": ["زٲژ"], "ksh": [ - "Saachjropp", "Saachjrop", "Katejori", "Kategorie", "Saachjrupp", - "Kattejori", "Sachjrop" + "Saachjropp", + "Saachjrop", + "Katejori", + "Kategorie", + 
"Saachjrupp", + "Kattejori", + "Sachjrop", ], "ku": ["Kategorî", "پۆل"], "kv": ["Категория"], @@ -558,207 +579,238 @@ class WikipediaDownloader(DocumentDownloader): - def __init__(self, download_dir, verbose=False): - super().__init__() - self._download_dir = download_dir - self._verbose = verbose - self._lock = Lock(name="wikipedia_downloader") - - def download(self, url): - urlpath = urlparse(url).path[1:] - output_name = urlpath.replace('/', '-') - output_file = os.path.join(self._download_dir, output_name) - if os.path.exists(output_file): - print(f"bz2 file: {output_file} exists. Not downloading") - else: - print(f"Downloading {url} and writing to {output_file}") - # Download with either wget or s5cmd (aws) - cmd = ["wget", url, "-O", output_file] - if self._verbose: - stdout, stderr = None, None - else: - stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL - with self._lock: - p = subprocess.run( - cmd, - stdout=stdout, - stderr=stderr, - ) - if p.returncode != 0: - print(f"Failed to download {url} to {output_file}") - - return output_file + def __init__(self, download_dir, verbose=False): + super().__init__() + self._download_dir = download_dir + self._verbose = verbose + self._lock = Lock(name="wikipedia_downloader") + + def download(self, url): + urlpath = urlparse(url).path[1:] + output_name = urlpath.replace("/", "-") + output_file = os.path.join(self._download_dir, output_name) + if os.path.exists(output_file): + print(f"bz2 file: {output_file} exists. Not downloading") + else: + print(f"Downloading {url} and writing to {output_file}") + # Download with either wget or s5cmd (aws) + cmd = ["wget", url, "-O", output_file] + if self._verbose: + stdout, stderr = None, None + else: + stdout, stderr = subprocess.DEVNULL, subprocess.DEVNULL + with self._lock: + p = subprocess.run( + cmd, + stdout=stdout, + stderr=stderr, + ) + if p.returncode != 0: + print(f"Failed to download {url} to {output_file}") + + return output_file class WikipediaIterator(DocumentIterator): - def __init__(self, language='en', log_frequency=1000): - super().__init__() - self._language = language - self._log_frequency = log_frequency - self._counter = 0 - - def iterate(self, file_path): - self._counter = 0 - bname = os.path.split(file_path)[-1] - input_file = bz2.BZ2File(filename=file_path) - utf_f = codecs.getreader("utf-8")(input_file) - context = etree.iterparse(utf_f, events=("end",)) - - for i, (unused_event, elem) in enumerate(context): - if not elem.tag.endswith("page"): - continue - if self._counter > 0 and self._counter % self._log_frequency == 0: - print(f"Extracted {self._counter} articles from {file_path}") - self._counter += 1 - - namespace = elem.tag[:-4] - title = elem.find(f"./{namespace}title").text - ns = elem.find(f"./{namespace}ns").text - id_ = elem.find(f"./{namespace}id").text - red_ = elem.find(f"./{namespace}redirect") - - url = f"https://{self._language}.wikipedia.org/wiki/{quote(title)}" - - # Filter pages that are not in the "main" namespace. - if ns != "0": - elem.clear() - continue - - raw_content = elem.find(f"./{namespace}revision/{namespace}text").text - elem.clear() - - # Filter redirects. 
- if raw_content is None or red_ is not None: - continue - - yield { - 'title': title, - 'id': id_, - 'url': url, - 'language': self._language, - 'source_id': f'{bname}', - }, raw_content + def __init__(self, language="en", log_frequency=1000): + super().__init__() + self._language = language + self._log_frequency = log_frequency + self._counter = 0 + + def iterate(self, file_path): + self._counter = 0 + bname = os.path.split(file_path)[-1] + input_file = bz2.BZ2File(filename=file_path) + utf_f = codecs.getreader("utf-8")(input_file) + context = etree.iterparse(utf_f, events=("end",)) + + for i, (unused_event, elem) in enumerate(context): + if not elem.tag.endswith("page"): + continue + if self._counter > 0 and self._counter % self._log_frequency == 0: + print(f"Extracted {self._counter} articles from {file_path}") + self._counter += 1 + + namespace = elem.tag[:-4] + title = elem.find(f"./{namespace}title").text + ns = elem.find(f"./{namespace}ns").text + id_ = elem.find(f"./{namespace}id").text + red_ = elem.find(f"./{namespace}redirect") + + url = f"https://{self._language}.wikipedia.org/wiki/{quote(title)}" + + # Filter pages that are not in the "main" namespace. + if ns != "0": + elem.clear() + continue + + raw_content = elem.find(f"./{namespace}revision/{namespace}text").text + elem.clear() + + # Filter redirects. + if raw_content is None or red_ is not None: + continue + + yield { + "title": title, + "id": id_, + "url": url, + "language": self._language, + "source_id": f"{bname}", + }, raw_content class WikipediaExtractor(DocumentExtractor): - def __init__(self, language='en', parser=mwparserfromhell): - super().__init__() - self._language = language - self._parser = parser - - def extract(self, content): - wikicode = self._parser.parse(content) - - # Filters for magic words that are parser instructions -- e.g., __NOTOC__ - re_rm_magic = re.compile("__[A-Z]*__", flags=re.UNICODE) - - # Filters for file/image links. - media_prefixes = "|".join(["File", "Image", "Media"] + - MEDIA_ALIASES.get(self._language, [])) - re_rm_wikilink = re.compile(f"^(?:{media_prefixes}):", - flags=re.IGNORECASE | re.UNICODE) - - def rm_wikilink(obj): - return bool(re_rm_wikilink.match(str(obj.title))) - - # Filters for references and tables - def rm_tag(obj): - return str(obj.tag) in {"ref", "table"} - - # Leave category links in-place but remove the category prefixes - cat_prefixes = "|".join(["Category"] + CAT_ALIASES.get(self._language, [])) - re_clean_wikilink = re.compile(f"^(?:{cat_prefixes}):", - flags=re.IGNORECASE | re.UNICODE) - - def is_category(obj): - return bool(re_clean_wikilink.match(str(obj.title))) - - def clean_wikilink(obj): - text = obj.__strip__() - text = re.sub(re_clean_wikilink, "", text) - obj.text = text - - def try_replace_obj(obj): - try: - clean_wikilink(obj) - except ValueError: - # For unknown reasons, objects are sometimes not found. - pass - - def try_remove_obj(obj, section): - try: - section.remove(obj) - except ValueError: - # For unknown reasons, objects are sometimes not found. - pass - - section_text = [] - # Filter individual sections to clean. 
- wiki_code_kwargs = { - 'flat': True, - 'include_lead': True, - 'include_headings': True, + def __init__(self, language="en", parser=mwparserfromhell): + super().__init__() + self._language = language + self._parser = parser + + def extract(self, content): + wikicode = self._parser.parse(content) + + # Filters for magic words that are parser instructions -- e.g., __NOTOC__ + re_rm_magic = re.compile("__[A-Z]*__", flags=re.UNICODE) + + # Filters for file/image links. + media_prefixes = "|".join( + ["File", "Image", "Media"] + MEDIA_ALIASES.get(self._language, []) + ) + re_rm_wikilink = re.compile( + f"^(?:{media_prefixes}):", flags=re.IGNORECASE | re.UNICODE + ) + + def rm_wikilink(obj): + return bool(re_rm_wikilink.match(str(obj.title))) + + # Filters for references and tables + def rm_tag(obj): + return str(obj.tag) in {"ref", "table"} + + # Leave category links in-place but remove the category prefixes + cat_prefixes = "|".join(["Category"] + CAT_ALIASES.get(self._language, [])) + re_clean_wikilink = re.compile( + f"^(?:{cat_prefixes}):", flags=re.IGNORECASE | re.UNICODE + ) + + def is_category(obj): + return bool(re_clean_wikilink.match(str(obj.title))) + + def clean_wikilink(obj): + text = obj.__strip__() + text = re.sub(re_clean_wikilink, "", text) + obj.text = text + + def try_replace_obj(obj): + try: + clean_wikilink(obj) + except ValueError: + # For unknown reasons, objects are sometimes not found. + pass + + def try_remove_obj(obj, section): + try: + section.remove(obj) + except ValueError: + # For unknown reasons, objects are sometimes not found. + pass + + section_text = [] + # Filter individual sections to clean. + wiki_code_kwargs = { + "flat": True, + "include_lead": True, + "include_headings": True, + } + for section in wikicode.get_sections(**wiki_code_kwargs): + for obj in section.ifilter_wikilinks(recursive=True): + if rm_wikilink(obj): + try_remove_obj(obj, section) + elif is_category(obj): + try_replace_obj(obj) + for obj in section.ifilter_tags(matches=rm_tag, recursive=True): + try_remove_obj(obj, section) + + section_text.append( + re.sub( + re_rm_magic, + "", + section.strip_code().strip(), + ) + ) + # Don't return any meta here + return {}, "\n\n".join(section_text) + + +def download_wikipedia( + output_path: str, + language: str = "en", + dump_date=None, + output_type: str = "jsonl", + raw_download_dir=None, + keep_raw_download=False, + force_download=False, + url_limit=None, +) -> DocumentDataset: + """ + Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell + + Args: + output_path: The path to the root directory of the files + language: The language of the Wikipedia articles to download + dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. + If None, latest dump is used. + output_type: The file type to save the data as. + raw_download_dir: Path to store the raw download files for intermediate processing. + If None, they are stored in a folder named "downloads" under output_path. + keep_raw_download: If True, keeps the bz2 files that have not been extracted. + force_download: If False, will skip processing all files in output_paths that already exist and + directly read from them instead. + url_limit: The maximum number of raw files to download from the snapshot. If None, all + files from the range of snapshots are downloaded. 
+ """ + wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date) + if url_limit: + wikipedia_urls = wikipedia_urls[:url_limit] + output_paths = list( + map( + lambda url: os.path.join( + output_path, url.split("/")[-1] + f".{output_type}" + ), + wikipedia_urls, + ) + ) + + if not raw_download_dir: + raw_download_dir = os.path.join(output_path, "downloads") + expand_outdir_and_mkdir(raw_download_dir) + + downloader = WikipediaDownloader(raw_download_dir) + iterator = WikipediaIterator(language=language) + extractor = WikipediaExtractor(language=language) + + output_format = { + "text": str, + "title": str, + "id": str, + "url": str, + "language": str, + "source_id": str, + "filename": str, } - for section in wikicode.get_sections(**wiki_code_kwargs): - for obj in section.ifilter_wikilinks(recursive=True): - if rm_wikilink(obj): - try_remove_obj(obj, section) - elif is_category(obj): - try_replace_obj(obj) - for obj in section.ifilter_tags(matches=rm_tag, recursive=True): - try_remove_obj(obj, section) - - section_text.append(re.sub( - re_rm_magic, - "", - section.strip_code().strip(), - )) - # Don't return any meta here - return {}, "\n\n".join(section_text) - - -def download_wikipedia(output_path: str, language: str="en", dump_date=None, output_type: str="jsonl", raw_download_dir=None, keep_raw_download=False, force_download=False, url_limit=None) -> DocumentDataset: - """ - Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell - - Args: - output_path: The path to the root directory of the files - language: The language of the Wikipedia articles to download - dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. - If None, latest dump is used. - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the bz2 files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. 
- """ - wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date) - if url_limit: - wikipedia_urls = wikipedia_urls[:url_limit] - output_paths = list(map(lambda url: os.path.join(output_path, url.split("/")[-1] + f".{output_type}"), wikipedia_urls)) - - if not raw_download_dir: - raw_download_dir = os.path.join(output_path, "downloads") - expand_outdir_and_mkdir(raw_download_dir) - - downloader = WikipediaDownloader(raw_download_dir) - iterator = WikipediaIterator(language=language) - extractor = WikipediaExtractor(language=language) - - output_format = { - "text": str, - "title": str, - "id": str, - "url": str, - "language": str, - "source_id": str, - "filename": str, - } - dataset = download_and_extract(wikipedia_urls, output_paths, downloader, iterator, extractor, output_format, output_type=output_type, keep_raw_download=keep_raw_download, force_download=force_download) - - return dataset \ No newline at end of file + dataset = download_and_extract( + wikipedia_urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + output_type=output_type, + keep_raw_download=keep_raw_download, + force_download=force_download, + ) + + return dataset diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py index 18f46afd..488df5ee 100644 --- a/nemo_curator/filters/__init__.py +++ b/nemo_curator/filters/__init__.py @@ -12,10 +12,81 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .classifier_filter import ( + BatchedFastTextQualityFilter, + FastTextLangId, + FastTextQualityFilter, +) +from .code import ( + AlphaFilter, + GeneralCommentToCodeFilter, + HTMLBoilerplateFilter, + NumberOfLinesOfCodeFilter, + PerExtensionFilter, + PythonCommentToCodeFilter, + TokenizerFertilityFilter, + XMLHeaderFilter, +) from .doc_filter import DocumentFilter, import_filter -from .classifier_filter import FastTextLangId, FastTextQualityFilter, BatchedFastTextQualityFilter -from .heuristic_filter import NonAlphaNumericFilter, SymbolsToWordsFilter, NumbersFilter, UrlsFilter, BulletsFilter, WhiteSpaceFilter, ParenthesesFilter, LongWordFilter, WordCountFilter, BoilerPlateStringFilter, MeanWordLengthFilter, RepeatedLinesFilter, RepeatedParagraphsFilter, RepeatedLinesByCharFilter, RepeatedParagraphsByCharFilter, RepeatingTopNGramsFilter, RepeatingDuplicateNGramsFilter, PunctuationFilter, EllipsisFilter, CommonEnglishWordsFilter, WordsWithoutAlphabetsFilter, PornographicUrlsFilter -from .code import PythonCommentToCodeFilter, GeneralCommentToCodeFilter, NumberOfLinesOfCodeFilter, TokenizerFertilityFilter, XMLHeaderFilter, AlphaFilter, HTMLBoilerplateFilter, PerExtensionFilter - -__all__ = ["BatchedFastTextQualityFilter", "DocumentFilter", "import_filter", "FastTextLangId", "FastTextQualityFilter", "NonAlphaNumericFilter", "SymbolsToWordsFilter", "NumbersFilter", "UrlsFilter", "BulletsFilter", "WhiteSpaceFilter", "ParenthesesFilter", "LongWordFilter", "WordCountFilter", "BoilerPlateStringFilter", "MeanWordLengthFilter", "RepeatedLinesFilter", "RepeatedParagraphsFilter", "RepeatedLinesByCharFilter", "RepeatedParagraphsByCharFilter", "RepeatingTopNGramsFilter", "RepeatingDuplicateNGramsFilter", "PunctuationFilter", "EllipsisFilter", "CommonEnglishWordsFilter", "WordsWithoutAlphabetsFilter", "PornographicUrlsFilter", "PythonCommentToCodeFilter", "GeneralCommentToCodeFilter", "NumberOfLinesOfCodeFilter", "TokenizerFertilityFilter", "XMLHeaderFilter", "AlphaFilter", "HTMLBoilerplateFilter", 
"PerExtensionFilter"] +from .heuristic_filter import ( + BoilerPlateStringFilter, + BulletsFilter, + CommonEnglishWordsFilter, + EllipsisFilter, + LongWordFilter, + MeanWordLengthFilter, + NonAlphaNumericFilter, + NumbersFilter, + ParenthesesFilter, + PornographicUrlsFilter, + PunctuationFilter, + RepeatedLinesByCharFilter, + RepeatedLinesFilter, + RepeatedParagraphsByCharFilter, + RepeatedParagraphsFilter, + RepeatingDuplicateNGramsFilter, + RepeatingTopNGramsFilter, + SymbolsToWordsFilter, + UrlsFilter, + WhiteSpaceFilter, + WordCountFilter, + WordsWithoutAlphabetsFilter, +) +__all__ = [ + "BatchedFastTextQualityFilter", + "DocumentFilter", + "import_filter", + "FastTextLangId", + "FastTextQualityFilter", + "NonAlphaNumericFilter", + "SymbolsToWordsFilter", + "NumbersFilter", + "UrlsFilter", + "BulletsFilter", + "WhiteSpaceFilter", + "ParenthesesFilter", + "LongWordFilter", + "WordCountFilter", + "BoilerPlateStringFilter", + "MeanWordLengthFilter", + "RepeatedLinesFilter", + "RepeatedParagraphsFilter", + "RepeatedLinesByCharFilter", + "RepeatedParagraphsByCharFilter", + "RepeatingTopNGramsFilter", + "RepeatingDuplicateNGramsFilter", + "PunctuationFilter", + "EllipsisFilter", + "CommonEnglishWordsFilter", + "WordsWithoutAlphabetsFilter", + "PornographicUrlsFilter", + "PythonCommentToCodeFilter", + "GeneralCommentToCodeFilter", + "NumberOfLinesOfCodeFilter", + "TokenizerFertilityFilter", + "XMLHeaderFilter", + "AlphaFilter", + "HTMLBoilerplateFilter", + "PerExtensionFilter", +] diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py index 74d54ad7..43c5bf9e 100644 --- a/nemo_curator/filters/classifier_filter.py +++ b/nemo_curator/filters/classifier_filter.py @@ -16,107 +16,114 @@ import numpy as np import pandas as pd -from nemo_curator.filters import DocumentFilter -from nemo_curator.utils.distributed_utils import load_object_on_worker, NoWorkerError +from nemo_curator.filters.doc_filter import DocumentFilter +from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker class FastTextQualityFilter(DocumentFilter): - def __init__(self, model_path=None, label='__label__hq', alpha=3, seed=42): - if model_path is None: - raise ValueError("Must provide a valid path to a FastText model " - "to compute document scores with this filter") - self._model_path = model_path - self._label = label - self._alpha = alpha - self._seed = np.random.seed(seed) - self._name = 'fasttext_quality_filter' - - def score_document(self, text): - text = text.replace('\n', ' ').replace('__label__', ' ') - model_attr = f"{self._name}_{self._model_path}" - # Workers don't exist during type inference - try: - model = load_object_on_worker(model_attr, self._load_model, {}) - except NoWorkerError: - return 1.0 - pred = model.predict(text) - document_score = pred[1][0] - if pred[0][0] != self._label: - document_score = 1 - document_score - - return document_score - - def keep_document(self, score): - return np.random.pareto(self._alpha) > 1 - score - - def _load_model(self): - return fasttext.load_model(self._model_path) + def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42): + if model_path is None: + raise ValueError( + "Must provide a valid path to a FastText model " + "to compute document scores with this filter" + ) + self._model_path = model_path + self._label = label + self._alpha = alpha + self._seed = np.random.seed(seed) + self._name = "fasttext_quality_filter" + + def score_document(self, text): + text = text.replace("\n", " 
").replace("__label__", " ") + model_attr = f"{self._name}_{self._model_path}" + # Workers don't exist during type inference + try: + model = load_object_on_worker(model_attr, self._load_model, {}) + except NoWorkerError: + return 1.0 + pred = model.predict(text) + document_score = pred[1][0] + if pred[0][0] != self._label: + document_score = 1 - document_score + + return document_score + + def keep_document(self, score): + return np.random.pareto(self._alpha) > 1 - score + + def _load_model(self): + return fasttext.load_model(self._model_path) + class BatchedFastTextQualityFilter(DocumentFilter): - def __init__(self, model_path=None, label='__label__hq', alpha=3, seed=42): - if model_path is None: - raise ValueError("Must provide a valid path to a FastText model " - "to compute document scores with this filter") - self._model_path = model_path - self._label = label - self._alpha = alpha - self._seed = np.random.seed(seed) - self._name = 'fasttext_quality_filter' - - def score_document(self, df): - model_attr = f"{self._name}_{self._model_path}" - try: - model = load_object_on_worker(model_attr, self._load_model, {}) - except NoWorkerError: - return pd.Series(np.ones(len(df)), dtype=float) - - def _score_document(text): - text = text.replace('\n', ' ').replace('__label__', ' ') - pred = model.predict(text) - document_score = pred[1][0] - if pred[0][0] != self._label: - document_score = 1 - document_score - - return document_score - - return df.apply(_score_document) - - def keep_document(self, df): - return np.random.pareto(self._alpha, size=len(df)) > 1 - df - - def _load_model(self): - return fasttext.load_model(self._model_path) + def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42): + if model_path is None: + raise ValueError( + "Must provide a valid path to a FastText model " + "to compute document scores with this filter" + ) + self._model_path = model_path + self._label = label + self._alpha = alpha + self._seed = np.random.seed(seed) + self._name = "fasttext_quality_filter" + + def score_document(self, df): + model_attr = f"{self._name}_{self._model_path}" + try: + model = load_object_on_worker(model_attr, self._load_model, {}) + except NoWorkerError: + return pd.Series(np.ones(len(df)), dtype=float) + + def _score_document(text): + text = text.replace("\n", " ").replace("__label__", " ") + pred = model.predict(text) + document_score = pred[1][0] + if pred[0][0] != self._label: + document_score = 1 - document_score + + return document_score + + return df.apply(_score_document) + + def keep_document(self, df): + return np.random.pareto(self._alpha, size=len(df)) > 1 - df + + def _load_model(self): + return fasttext.load_model(self._model_path) class FastTextLangId(DocumentFilter): - def __init__(self, model_path=None, min_langid_score=0.3): - if model_path is None: - raise ValueError("Must provide a valid path to a FastText model " - "to identify languages with this filter") - self._model_path = model_path - self._lang_code = None - self._cutoff = min_langid_score - self._name = "lang_id" - - def score_document(self, text): - pp = text.strip().replace('\n', ' ') - - model_attr = f"{self._name}_{self._model_path}" - try: - model = load_object_on_worker(model_attr, self._load_model, {}) - except NoWorkerError: - return [1.0, 'N/A'] - label, score = model.predict(pp, k=1) - score = score[0] - lang_code = label[0][-2:].upper() - - return [score, lang_code] - - def keep_document(self, score): - return score[0] >= self._cutoff - - def _load_model(self): - return 
fasttext.load_model(self._model_path) + def __init__(self, model_path=None, min_langid_score=0.3): + if model_path is None: + raise ValueError( + "Must provide a valid path to a FastText model " + "to identify languages with this filter" + ) + self._model_path = model_path + self._lang_code = None + self._cutoff = min_langid_score + self._name = "lang_id" + + def score_document(self, text): + pp = text.strip().replace("\n", " ") + + model_attr = f"{self._name}_{self._model_path}" + try: + model = load_object_on_worker(model_attr, self._load_model, {}) + except NoWorkerError: + return [1.0, "N/A"] + label, score = model.predict(pp, k=1) + score = score[0] + lang_code = label[0][-2:].upper() + + return [score, lang_code] + + def keep_document(self, score): + return score[0] >= self._cutoff + + def _load_model(self): + return fasttext.load_model(self._model_path) diff --git a/nemo_curator/filters/code.py b/nemo_curator/filters/code.py index 44868c29..9a209ec4 100644 --- a/nemo_curator/filters/code.py +++ b/nemo_curator/filters/code.py @@ -12,292 +12,309 @@ # See the License for the specific language governing permissions and # limitations under the License. +import csv import warnings -from comment_parser import comment_parser + import numpy as np -import csv from bs4 import BeautifulSoup +from comment_parser import comment_parser from nemo.collections.common.tokenizers import SentencePieceTokenizer -from nemo_curator.filters import DocumentFilter, import_filter -from nemo_curator.utils.text_utils import get_comments_and_docstring +from nemo_curator.filters.doc_filter import DocumentFilter, import_filter from nemo_curator.utils.constants import regex_alpha, regex_alphanum +from nemo_curator.utils.text_utils import get_comments_and_docstring class PythonCommentToCodeFilter(DocumentFilter): - def __init__( - self, - min_comment_to_code_ratio=0.01, - max_comment_to_code_ratio=0.85, - ): - self._min_threshold = min_comment_to_code_ratio - self._max_threshold = max_comment_to_code_ratio + def __init__( + self, + min_comment_to_code_ratio=0.01, + max_comment_to_code_ratio=0.85, + ): + self._min_threshold = min_comment_to_code_ratio + self._max_threshold = max_comment_to_code_ratio - self._name = 'python_comment_ratio' + self._name = "python_comment_ratio" - def score_document(self, source): - docstrings, comments = get_comments_and_docstring(source) - if docstrings is None or comments is None: - return 0 - # No need to be super precise about the way this formatted, - # we just need to count the number of characters. - return (len(comments) + len(docstrings)) / len(source) + def score_document(self, source): + docstrings, comments = get_comments_and_docstring(source) + if docstrings is None or comments is None: + return 0 + # No need to be super precise about the way this formatted, + # we just need to count the number of characters. + return (len(comments) + len(docstrings)) / len(source) - def keep_document(self, score): - return self._min_threshold <= score <= self._max_threshold + def keep_document(self, score): + return self._min_threshold <= score <= self._max_threshold class GeneralCommentToCodeFilter(DocumentFilter): - def __init__( - self, - language, - min_comment_to_code_ratio=0.01, - max_comment_to_code_ratio=0.85, - ): - """ - Does not include the comment characters (// or /**/) towards the length of the comment. 
- Args: - language: Mime string of language - """ - self._lang = language - self._min_threshold = min_comment_to_code_ratio = min_comment_to_code_ratio - self._max_threshold = max_comment_to_code_ratio = max_comment_to_code_ratio - - self._name = 'comment_ratio' - - def score_document(self, source): - try: - comments = comment_parser.extract_comments_from_str( - source, - mime=self._lang, - ) - comments = ' '.join([x.text() for x in comments]) - except Exception: - warnings.warn("tokenization error, no comment is extracted") - return 9999 - if comments is None: - return 0 - return len(comments) / len(source) - - def keep_document(self, score): - return self._min_threshold <= score <= self._max_threshold + def __init__( + self, + language, + min_comment_to_code_ratio=0.01, + max_comment_to_code_ratio=0.85, + ): + """ + Does not include the comment characters (// or /**/) towards the length of the comment. + Args: + language: Mime string of language + """ + self._lang = language + self._min_threshold = min_comment_to_code_ratio = min_comment_to_code_ratio + self._max_threshold = max_comment_to_code_ratio = max_comment_to_code_ratio + + self._name = "comment_ratio" + + def score_document(self, source): + try: + comments = comment_parser.extract_comments_from_str( + source, + mime=self._lang, + ) + comments = " ".join([x.text() for x in comments]) + except Exception: + warnings.warn("tokenization error, no comment is extracted") + return 9999 + if comments is None: + return 0 + return len(comments) / len(source) + + def keep_document(self, score): + return self._min_threshold <= score <= self._max_threshold class NumberOfLinesOfCodeFilter(DocumentFilter): - def __init__(self, min_lines=10, max_lines=20000): - self._min_lines = min_lines - self._max_lines = max_lines + def __init__(self, min_lines=10, max_lines=20000): + self._min_lines = min_lines + self._max_lines = max_lines - self._name = 'num_lines' + self._name = "num_lines" - def score_document(self, source): - return len(source.split('\n')) + def score_document(self, source): + return len(source.split("\n")) - def keep_document(self, score): - return self._min_lines <= score <= self._max_lines + def keep_document(self, score): + return self._min_lines <= score <= self._max_lines class TokenizerFertilityFilter(DocumentFilter): - def __init__(self, path_to_tokenizer=None, min_char_to_token_ratio=2.5): - if path_to_tokenizer is None: - raise ValueError("Must provide a valid path to a SentencePiece " - "tokenizer") - self._tokenizer = SentencePieceTokenizer(path_to_tokenizer) - self._threshold = min_char_to_token_ratio + def __init__(self, path_to_tokenizer=None, min_char_to_token_ratio=2.5): + if path_to_tokenizer is None: + raise ValueError( + "Must provide a valid path to a SentencePiece " "tokenizer" + ) + self._tokenizer = SentencePieceTokenizer(path_to_tokenizer) + self._threshold = min_char_to_token_ratio - self._name = 'tokenizer_fertility' + self._name = "tokenizer_fertility" - def score_document(self, source): - tokens = self._tokenizer.text_to_tokens(source) - num_chars = len(source) - num_tokens = len(tokens) - if num_tokens == 0: - return -1 - return num_chars / num_tokens + def score_document(self, source): + tokens = self._tokenizer.text_to_tokens(source) + num_chars = len(source) + num_tokens = len(tokens) + if num_tokens == 0: + return -1 + return num_chars / num_tokens - def keep_document(self, score): - return score >= self._threshold + def keep_document(self, score): + return score >= self._threshold class 
XMLHeaderFilter(DocumentFilter): - """ - This filter tries to identify files that have incorrect file extensions. - In many cases, these end up being XML files and we try to identify them - based on the header. - (Source: Starcoder https://arxiv.org/abs/2305.06161) - """ - def __init__(self, char_prefix_search_length=100): - self._char_prefix_search_length = char_prefix_search_length + """ + This filter tries to identify files that have incorrect file extensions. + In many cases, these end up being XML files and we try to identify them + based on the header. + (Source: Starcoder https://arxiv.org/abs/2305.06161) + """ - self._name = 'xml_header' + def __init__(self, char_prefix_search_length=100): + self._char_prefix_search_length = char_prefix_search_length - def score_document(self, source): - source_prefix = source[:self._char_prefix_search_length] - if "= self._min_alpha_ratio + def score_document(self, source): + return len(regex_alpha.findall(source)) / len(source) + + def keep_document(self, score): + return score >= self._min_alpha_ratio class HTMLBoilerplateFilter(DocumentFilter): - """ - This filter tries to identify HTML that is largely boilerplate. - """ - def __init__(self, min_lang_content_ratio=0.2, min_lang_content_num_chars=100): - self._min_lang_content_ratio = min_lang_content_ratio - self._min_lang_content_num_chars = min_lang_content_num_chars + """ + This filter tries to identify HTML that is largely boilerplate. + """ + + def __init__(self, min_lang_content_ratio=0.2, min_lang_content_num_chars=100): + self._min_lang_content_ratio = min_lang_content_ratio + self._min_lang_content_num_chars = min_lang_content_num_chars + + self._name = "html_boilerplate" - self._name = 'html_boilerplate' + def score_document(self, source): + try: + soup = BeautifulSoup(source, features="html.parser") + except (TypeError, UnboundLocalError): + return None - def score_document(self, source): - try: - soup = BeautifulSoup(source, features="html.parser") - except (TypeError, UnboundLocalError): - return None - - # kill all script and style elements - for script in soup(["script", "style"]): - script.extract() # rip it out + # kill all script and style elements + for script in soup(["script", "style"]): + script.extract() # rip it out - # get text - text = soup.get_text() - ratio = len(text) / len(source) + # get text + text = soup.get_text() + ratio = len(text) / len(source) - if len(text) < self._min_lang_content_num_chars: - return 0 + if len(text) < self._min_lang_content_num_chars: + return 0 - return ratio + return ratio - def keep_document(self, score): - return score >= self._min_lang_content_ratio + def keep_document(self, score): + return score >= self._min_lang_content_ratio class PerExtensionFilter(DocumentFilter): - """ - This filter that has specific conditions depending on the file extension. - """ - def __init__(self, lang, extension, metadata_file='code_meta.csv'): - self._metadata_file = metadata_file - self._lang = lang - self._extension = extension - self._ext_to_filter = self._load_filter_csv(metadata_file, lang) - self._name = 'per_extension_filter' - - def _load_filter_csv(self, path: str, language: str = None): - """Load csv file that specifies the filter to apply for each (lang, extension). - TODO: add some tests. 
Check that filters are correctly set.""" - # (Lang, extension) -> filter_args - ext_to_filter = {} - with open(path) as f: - for row in csv.DictReader(f): - # Only take the rows corresponding to the language if specified - if language is None or row["language"] == language: - ext_to_filter[(row["language"], row["extension"])] = self._get_filter_params(row) - assert len(ext_to_filter) > 0, f"Did not find filtering params corresponding to language: `{language}` in: {path}" - - return ext_to_filter - - def _get_filter_params(self, row: dict): - """Extract filter parameters from csv row""" - include = row["Include"] == "1" - try: - line_max = int(row["Long_line_threshold"]) - except ValueError: - line_max = None - line_mean = 100 if line_max else None - try: - alphanum_frac = float(row["Alphanum_threshold"]) - except ValueError: - alphanum_frac = None - try: - alphabetic_frac = float(row["Alpha filter"]) - except ValueError: - alphabetic_frac = None - return include, line_max, line_mean, alphanum_frac, alphabetic_frac - - def _language_format_from_dataset(self, lang: str): - """Convert: Language field in dataset -> language field in csv file that defines the filters.""" - # TODO: other special cases? - if lang == "C#": - return "c-sharp" - if lang == "F#": - return "f-sharp" - return lang.lower().replace(" ", "-") - - def _line_statistics(self, source): - lengths = [len(x) for x in source.split('\n')] - max_length = max(lengths) - mean_length = np.mean(lengths) - - return max_length, mean_length - - def _alphanum_fraction(self, source): - return len(regex_alphanum.findall(source)) / len(source) - - def score_document(self, source): - """Filter files based on line length and % alphanumeric characters. - The filtering parameters depend on the file extension, given by `ext_to_filter`""" - # Get the filter-params we want to use - # extension `None` is an empty string in the csv - try: - (include, line_max, line_mean, alphanum_frac, alphabetic_frac) = self._ext_to_filter[(self._language_format_from_dataset( - self._lang), self._extension if self._extension is not None else "" - )] - except KeyError as e: - # Some extensions are not in the csv. This happens for dockerfiles. - # Exclude these files - print(str(e) + f":{self._extension} not in ext_to_filter") - include = False - - if not include: - return 0 - - max_length, mean_length = self._line_statistics(source) - - if line_max and max_length > line_max: - return 0 - - elif line_mean and mean_length > line_mean: - return 0 - - # Filter files with low percentage of alphanumeric chars - elif alphanum_frac and self._alphanum_fraction(source) < alphanum_frac: - return 0 - - # Filter files with low percentage of alphabetic chars - elif alphabetic_frac and sum(map(str.isalpha, source)) < alphabetic_frac * len(source): - return 0 - - return 1 - - def keep_document(self, score): - if score is None or score == 0: - return False - else: - return True + """ + This filter that has specific conditions depending on the file extension. + """ + def __init__(self, lang, extension, metadata_file="code_meta.csv"): + self._metadata_file = metadata_file + self._lang = lang + self._extension = extension + self._ext_to_filter = self._load_filter_csv(metadata_file, lang) + self._name = "per_extension_filter" + + def _load_filter_csv(self, path: str, language: str = None): + """Load csv file that specifies the filter to apply for each (lang, extension). + TODO: add some tests. 
Check that filters are correctly set.""" + # (Lang, extension) -> filter_args + ext_to_filter = {} + with open(path) as f: + for row in csv.DictReader(f): + # Only take the rows corresponding to the language if specified + if language is None or row["language"] == language: + ext_to_filter[(row["language"], row["extension"])] = ( + self._get_filter_params(row) + ) + assert ( + len(ext_to_filter) > 0 + ), f"Did not find filtering params corresponding to language: `{language}` in: {path}" + + return ext_to_filter + + def _get_filter_params(self, row: dict): + """Extract filter parameters from csv row""" + include = row["Include"] == "1" + try: + line_max = int(row["Long_line_threshold"]) + except ValueError: + line_max = None + line_mean = 100 if line_max else None + try: + alphanum_frac = float(row["Alphanum_threshold"]) + except ValueError: + alphanum_frac = None + try: + alphabetic_frac = float(row["Alpha filter"]) + except ValueError: + alphabetic_frac = None + return include, line_max, line_mean, alphanum_frac, alphabetic_frac + + def _language_format_from_dataset(self, lang: str): + """Convert: Language field in dataset -> language field in csv file that defines the filters.""" + # TODO: other special cases? + if lang == "C#": + return "c-sharp" + if lang == "F#": + return "f-sharp" + return lang.lower().replace(" ", "-") + + def _line_statistics(self, source): + lengths = [len(x) for x in source.split("\n")] + max_length = max(lengths) + mean_length = np.mean(lengths) + + return max_length, mean_length + + def _alphanum_fraction(self, source): + return len(regex_alphanum.findall(source)) / len(source) + + def score_document(self, source): + """Filter files based on line length and % alphanumeric characters. + The filtering parameters depend on the file extension, given by `ext_to_filter` + """ + # Get the filter-params we want to use + # extension `None` is an empty string in the csv + try: + (include, line_max, line_mean, alphanum_frac, alphabetic_frac) = ( + self._ext_to_filter[ + ( + self._language_format_from_dataset(self._lang), + self._extension if self._extension is not None else "", + ) + ] + ) + except KeyError as e: + # Some extensions are not in the csv. This happens for dockerfiles. 
+ # Exclude these files + print(str(e) + f":{self._extension} not in ext_to_filter") + include = False + + if not include: + return 0 + + max_length, mean_length = self._line_statistics(source) + + if line_max and max_length > line_max: + return 0 + + elif line_mean and mean_length > line_mean: + return 0 + + # Filter files with low percentage of alphanumeric chars + elif alphanum_frac and self._alphanum_fraction(source) < alphanum_frac: + return 0 + + # Filter files with low percentage of alphabetic chars + elif alphabetic_frac and sum(map(str.isalpha, source)) < alphabetic_frac * len( + source + ): + return 0 + + return 1 + + def keep_document(self, score): + if score is None or score == 0: + return False + else: + return True diff --git a/nemo_curator/filters/doc_filter.py b/nemo_curator/filters/doc_filter.py index c6013e47..4ec10ccc 100644 --- a/nemo_curator/filters/doc_filter.py +++ b/nemo_curator/filters/doc_filter.py @@ -18,55 +18,57 @@ class DocumentFilter(ABC): - def __init__(self): - super().__init__() - self._name = self.__class__.__name__ - self._sentences = None - self._paragraphs = None - self._ngrams = None + def __init__(self): + super().__init__() + self._name = self.__class__.__name__ + self._sentences = None + self._paragraphs = None + self._ngrams = None - @abstractmethod - def score_document(self, text): - pass + @abstractmethod + def score_document(self, text): + pass - @abstractmethod - def keep_document(self, scores): - pass + @abstractmethod + def keep_document(self, scores): + pass - @property - def name(self): - return self._name + @property + def name(self): + return self._name - @property - def sentences(self): - return self._sentences + @property + def sentences(self): + return self._sentences - @sentences.setter - def sentences(self, sentences): - self._sentences = sentences + @sentences.setter + def sentences(self, sentences): + self._sentences = sentences - @property - def paragraphs(self): - return self._paragraphs + @property + def paragraphs(self): + return self._paragraphs - @paragraphs.setter - def paragraphs(self, paragraphs): - self._paragraphs = paragraphs + @paragraphs.setter + def paragraphs(self, paragraphs): + self._paragraphs = paragraphs - @property - def ngrams(self): - return self._ngrams + @property + def ngrams(self): + return self._ngrams - @ngrams.setter - def ngrams(self, ngrams): - self._ngrams = ngrams + @ngrams.setter + def ngrams(self, ngrams): + self._ngrams = ngrams def import_filter(filter_path): - module_path, filter_name = filter_path.rsplit(".", 1) - filter_module = importlib.import_module(module_path) - filter_class = getattr(filter_module, filter_name) - if not issubclass(filter_class, DocumentFilter): - raise ValueError(f"Input filter {filter_class.__name__} must be derived " - "from DocumentFilter defined in nemo_curator.filters.doc_filter") - return filter_class + module_path, filter_name = filter_path.rsplit(".", 1) + filter_module = importlib.import_module(module_path) + filter_class = getattr(filter_module, filter_name) + if not issubclass(filter_class, DocumentFilter): + raise ValueError( + f"Input filter {filter_class.__name__} must be derived " + "from DocumentFilter defined in nemo_curator.filters.doc_filter" + ) + return filter_class diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index 5936ede9..62638440 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -14,623 +14,622 @@ import regex -from nemo_curator.filters 
import DocumentFilter, import_filter +from nemo_curator.filters.doc_filter import DocumentFilter, import_filter from nemo_curator.utils.constants import ( + bullet_list, + common_english_words, + ellipsis_marks, + end_marks, + policy_substrings, + regex_alpha, regex_alphanum, - regex_hash, regex_digit, - regex_url, + regex_hash, regex_paren, - regex_alpha, - ellipsis_marks, - bullet_list, + regex_url, white_space_list, - policy_substrings, - end_marks, - common_english_words, ) from nemo_curator.utils.text_utils import ( - get_word_splitter, + get_ngrams, get_paragraphs, get_sentences, - get_ngrams, + get_word_splitter, is_paragraph_indices_in_top_or_bottom_only, ) class NonAlphaNumericFilter(DocumentFilter): - """ - If more than 25% of the document is non-alphanumeric then discard - Intended to be applied only too english text - Source: Adapted from Gopher (Rae et al., 2021) - """ - - def __init__(self, max_non_alpha_numeric_to_text_ratio=0.25): - super().__init__() - self._cutoff = max_non_alpha_numeric_to_text_ratio - self._name = 'alpha_numeric' - - def score_document(self, text): - nchar = len(text) - if nchar > 0: - score = (nchar - len(regex_alphanum.findall(text))) / nchar - else: - # Remove the document if it is empty - score = 1.0 - return score - - def keep_document(self, score): - return score <= self._cutoff + """ + If more than 25% of the document is non-alphanumeric then discard + Intended to be applied only too english text + Source: Adapted from Gopher (Rae et al., 2021) + """ + + def __init__(self, max_non_alpha_numeric_to_text_ratio=0.25): + super().__init__() + self._cutoff = max_non_alpha_numeric_to_text_ratio + self._name = "alpha_numeric" + + def score_document(self, text): + nchar = len(text) + if nchar > 0: + score = (nchar - len(regex_alphanum.findall(text))) / nchar + else: + # Remove the document if it is empty + score = 1.0 + return score + + def keep_document(self, score): + return score <= self._cutoff class SymbolsToWordsFilter(DocumentFilter): - """ - Remove any document with symbol-to-word ratio greater than - 0.1 for either the hash symbol or the elipsis - Source: Gopher (Rae et al., 2021) - """ - - def __init__(self, max_symbol_to_word_ratio=0.1, lang='en'): - super().__init__() - self._cutoff = max_symbol_to_word_ratio - self._word_splitter = get_word_splitter(lang) - self._name = 'symbol_to_word' - - def score_document(self, text): - num_symbol_words = 0 - words = self._word_splitter(text.strip()) - for word in words: - word = word.strip() - # Checks if the word is an elipsis or consists mostly of symbols. - symbol_ratio = len(regex_hash.findall(word)) / len(word) - if word in ellipsis_marks or symbol_ratio > 0.5: - num_symbol_words += 1 - return num_symbol_words / len(words) - - def keep_document(self, score): - return score <= self._cutoff + """ + Remove any document with symbol-to-word ratio greater than + 0.1 for either the hash symbol or the elipsis + Source: Gopher (Rae et al., 2021) + """ + + def __init__(self, max_symbol_to_word_ratio=0.1, lang="en"): + super().__init__() + self._cutoff = max_symbol_to_word_ratio + self._word_splitter = get_word_splitter(lang) + self._name = "symbol_to_word" + + def score_document(self, text): + num_symbol_words = 0 + words = self._word_splitter(text.strip()) + for word in words: + word = word.strip() + # Checks if the word is an elipsis or consists mostly of symbols. 
+ symbol_ratio = len(regex_hash.findall(word)) / len(word) + if word in ellipsis_marks or symbol_ratio > 0.5: + num_symbol_words += 1 + return num_symbol_words / len(words) + + def keep_document(self, score): + return score <= self._cutoff class NumbersFilter(DocumentFilter): - """ - If more than 15% of the document contains numbers then discard - """ + """ + If more than 15% of the document contains numbers then discard + """ - def __init__(self, max_number_to_text_ratio=0.15): - super().__init__() - self._cutoff = max_number_to_text_ratio - self._name = 'numbers_ratio' + def __init__(self, max_number_to_text_ratio=0.15): + super().__init__() + self._cutoff = max_number_to_text_ratio + self._name = "numbers_ratio" - def score_document(self, text): - nchar = len(text) - if nchar > 0: - score = len(regex_digit.findall(text)) / nchar - else: - # Remove if the document is empty - score = 1.0 - return score + def score_document(self, text): + nchar = len(text) + if nchar > 0: + score = len(regex_digit.findall(text)) / nchar + else: + # Remove if the document is empty + score = 1.0 + return score - def keep_document(self, score): - return score <= self._cutoff + def keep_document(self, score): + return score <= self._cutoff class UrlsFilter(DocumentFilter): - """ - If more than 20% of the document is comprised of URLs then discard - """ - - def __init__(self, max_url_to_text_ratio=0.2): - super().__init__() - self._cutoff = max_url_to_text_ratio - self._name = 'urls_ratio' - - def score_document(self, text): - all_urls = regex_url.findall(text) - url_chars = sum([len(url) for url in all_urls]) - nchar = len(text) - if nchar > 0: - score = url_chars / nchar - else: - # Remove if the document is empty - score = 1.0 - return score - - def keep_document(self, score): - return score <= self._cutoff + """ + If more than 20% of the document is comprised of URLs then discard + """ + + def __init__(self, max_url_to_text_ratio=0.2): + super().__init__() + self._cutoff = max_url_to_text_ratio + self._name = "urls_ratio" + + def score_document(self, text): + all_urls = regex_url.findall(text) + url_chars = sum([len(url) for url in all_urls]) + nchar = len(text) + if nchar > 0: + score = url_chars / nchar + else: + # Remove if the document is empty + score = 1.0 + return score + + def keep_document(self, score): + return score <= self._cutoff class BulletsFilter(DocumentFilter): - """ - If more than 90% of the lines start with a bullet then discard - Source: Gopher (Rae et al., 2021) - """ - - def __init__(self, max_bullet_lines_ratio=0.9): - super().__init__() - self._cutoff = max_bullet_lines_ratio - self._sentences = None - self._name = 'bullet_ratio' - - def score_document(self, text): - # Get sentences - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) - num_bullet_lines = 0 - for sentence in sentences: - for bullet in bullet_list: - if sentence.strip().startswith(bullet): - num_bullet_lines += 1 - break - return num_bullet_lines / len(sentences) - - def keep_document(self, score): - return score <= self._cutoff + """ + If more than 90% of the lines start with a bullet then discard + Source: Gopher (Rae et al., 2021) + """ + + def __init__(self, max_bullet_lines_ratio=0.9): + super().__init__() + self._cutoff = max_bullet_lines_ratio + self._sentences = None + self._name = "bullet_ratio" + + def score_document(self, text): + # Get sentences + sentences = self._sentences + if sentences is None: + sentences = get_sentences(text) + num_bullet_lines = 0 + for sentence 
in sentences: + for bullet in bullet_list: + if sentence.strip().startswith(bullet): + num_bullet_lines += 1 + break + return num_bullet_lines / len(sentences) + + def keep_document(self, score): + return score <= self._cutoff class WhiteSpaceFilter(DocumentFilter): - """ - If the document contains a significant number - of white space characters then discard - """ - - def __init__(self, max_white_space_ratio=0.25): - super().__init__() - self._cutoff = max_white_space_ratio - self._name = 'white_space' - - def score_document(self, text): - # Do not strip the document since we want to - # include leading and trailing whitepsaces as well. - nchar = len(text) - if nchar > 0: - score = len([x for x in text if x in white_space_list]) / nchar - else: - # Remove if the document is empty - score = 1.0 - return score - - def keep_document(self, score): - return score <= self._cutoff + """ + If the document contains a significant number + of white space characters then discard + """ + + def __init__(self, max_white_space_ratio=0.25): + super().__init__() + self._cutoff = max_white_space_ratio + self._name = "white_space" + + def score_document(self, text): + # Do not strip the document since we want to + # include leading and trailing whitepsaces as well. + nchar = len(text) + if nchar > 0: + score = len([x for x in text if x in white_space_list]) / nchar + else: + # Remove if the document is empty + score = 1.0 + return score + + def keep_document(self, score): + return score <= self._cutoff class ParenthesesFilter(DocumentFilter): - """ - If more than 10% of the sentence is in parentheses then discard - """ + """ + If more than 10% of the sentence is in parentheses then discard + """ - def __init__(self, max_parentheses_ratio=0.1): - super().__init__() - self._max_parentheses_ratio = max_parentheses_ratio - self._name = 'parentheses_ratio' + def __init__(self, max_parentheses_ratio=0.1): + super().__init__() + self._max_parentheses_ratio = max_parentheses_ratio + self._name = "parentheses_ratio" - def score_document(self, text): - nchar = len(text) - if nchar > 0: - score = len(regex_paren.findall(text)) / nchar - else: - # Remove if the document is empty - score = 1.0 - return score + def score_document(self, text): + nchar = len(text) + if nchar > 0: + score = len(regex_paren.findall(text)) / nchar + else: + # Remove if the document is empty + score = 1.0 + return score - def keep_document(self, score): - return score <= self._max_parentheses_ratio + def keep_document(self, score): + return score <= self._max_parentheses_ratio class LongWordFilter(DocumentFilter): - """ - If the document contains a word longer than 1000 characters then discard - NOTE: This seems to be catching things like minified `.js` files - that don't have spaces anywhere. - Source: C4 (Google) - """ + """ + If the document contains a word longer than 1000 characters then discard + NOTE: This seems to be catching things like minified `.js` files + that don't have spaces anywhere. 
+ Source: C4 (Google) + """ - def __init__(self, max_word_length=1000, lang='en'): - super().__init__() - self._max_word_length = max_word_length - self._word_splitter = get_word_splitter(lang) - self._name = 'max_word_length' + def __init__(self, max_word_length=1000, lang="en"): + super().__init__() + self._max_word_length = max_word_length + self._word_splitter = get_word_splitter(lang) + self._name = "max_word_length" - def score_document(self, text): - return max(len(w) for w in self._word_splitter(text.strip())) + def score_document(self, text): + return max(len(w) for w in self._word_splitter(text.strip())) - def keep_document(self, score): - return score <= self._max_word_length + def keep_document(self, score): + return score <= self._max_word_length class WordCountFilter(DocumentFilter): - """ - If a document contains a number of words not - within a specified range then discard - """ + """ + If a document contains a number of words not + within a specified range then discard + """ - def __init__(self, min_words=50, max_words=100000, lang='en'): - super().__init__() - self._min_words = min_words - self._max_words = max_words - self._word_splitter = get_word_splitter(lang) - self._name = 'word_count' + def __init__(self, min_words=50, max_words=100000, lang="en"): + super().__init__() + self._min_words = min_words + self._max_words = max_words + self._word_splitter = get_word_splitter(lang) + self._name = "word_count" - def score_document(self, text): - return len(self._word_splitter(text.strip())) + def score_document(self, text): + return len(self._word_splitter(text.strip())) - def keep_document(self, score): - return self._min_words <= score <= self._max_words + def keep_document(self, score): + return self._min_words <= score <= self._max_words class BoilerPlateStringFilter(DocumentFilter): - """ - If more than 40% of paragraphs contain boilerplate strings then discard. - This includes things like "terms of use", "privacy policy", etc. - Source: Adapted significantly from Google C4 processing. - """ - - def __init__( - self, - remove_if_at_top_or_bottom=True, - max_boilerplate_string_ratio=0.4, - ): - super().__init__() - self._remove_if_at_top_or_bottom = remove_if_at_top_or_bottom - self._max_boilerplate_string_ratio = max_boilerplate_string_ratio - self._boilerplate_paragraph_indices = [] - self._max_ratio = 1.0 - self._name = 'boilerplate_string_ratio' - - def score_document(self, text): - # Initialize variables - boilerplate_paragraph_count = 0 - - # Get the paragraphs - paragraphs = get_paragraphs(text) - - # Check each paragraph - for idx, paragraph in enumerate(paragraphs): - paragraph = paragraph.strip().lower() - if 'lorem ipsum' in paragraph: - return self._max_ratio - if any(p in paragraph for p in policy_substrings): - boilerplate_paragraph_count += 1 - - return boilerplate_paragraph_count / len(paragraphs) - - def keep_document(self, score): - return score <= self._max_boilerplate_string_ratio + """ + If more than 40% of paragraphs contain boilerplate strings then discard. + This includes things like "terms of use", "privacy policy", etc. + Source: Adapted significantly from Google C4 processing. 
+ """ + + def __init__( + self, + remove_if_at_top_or_bottom=True, + max_boilerplate_string_ratio=0.4, + ): + super().__init__() + self._remove_if_at_top_or_bottom = remove_if_at_top_or_bottom + self._max_boilerplate_string_ratio = max_boilerplate_string_ratio + self._boilerplate_paragraph_indices = [] + self._max_ratio = 1.0 + self._name = "boilerplate_string_ratio" + + def score_document(self, text): + # Initialize variables + boilerplate_paragraph_count = 0 + + # Get the paragraphs + paragraphs = get_paragraphs(text) + + # Check each paragraph + for idx, paragraph in enumerate(paragraphs): + paragraph = paragraph.strip().lower() + if "lorem ipsum" in paragraph: + return self._max_ratio + if any(p in paragraph for p in policy_substrings): + boilerplate_paragraph_count += 1 + + return boilerplate_paragraph_count / len(paragraphs) + + def keep_document(self, score): + return score <= self._max_boilerplate_string_ratio class MeanWordLengthFilter(DocumentFilter): - """ - If the mean word length is not in a specified range then discard - """ - - def __init__( - self, - min_mean_word_length=3, - max_mean_word_length=10, - lang='en', - ): - super().__init__() - self._min_cutoff = min_mean_word_length - self._max_cutoff = max_mean_word_length - self._word_splitter = get_word_splitter(lang) - self._name = 'mean_word_length' - - def score_document(self, text): - word_lens = [ - len(w) for w in self._word_splitter(text.strip()) if len(w) > 0 - ] - return sum(word_lens) / len(word_lens) - - def keep_document(self, score): - return self._min_cutoff <= score <= self._max_cutoff + """ + If the mean word length is not in a specified range then discard + """ + + def __init__( + self, + min_mean_word_length=3, + max_mean_word_length=10, + lang="en", + ): + super().__init__() + self._min_cutoff = min_mean_word_length + self._max_cutoff = max_mean_word_length + self._word_splitter = get_word_splitter(lang) + self._name = "mean_word_length" + + def score_document(self, text): + word_lens = [len(w) for w in self._word_splitter(text.strip()) if len(w) > 0] + return sum(word_lens) / len(word_lens) + + def keep_document(self, score): + return self._min_cutoff <= score <= self._max_cutoff class RepeatedLinesFilter(DocumentFilter): - """ - If the document shrinks by > 30% in terms of number of lines after - removing duplicate lines then discard - Source: Gopher (Rae et al., 2021) - """ + """ + If the document shrinks by > 30% in terms of number of lines after + removing duplicate lines then discard + Source: Gopher (Rae et al., 2021) + """ - def __init__(self, max_repeated_line_fraction=0.7): - super().__init__() - self._cutoff = max_repeated_line_fraction - self._name = 'repeated_lines' + def __init__(self, max_repeated_line_fraction=0.7): + super().__init__() + self._cutoff = max_repeated_line_fraction + self._name = "repeated_lines" - def score_document(self, text): - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) - return len(set(sentences)) / len(sentences) + def score_document(self, text): + sentences = self._sentences + if sentences is None: + sentences = get_sentences(text) + return len(set(sentences)) / len(sentences) - def keep_document(self, score): - return score >= self._cutoff + def keep_document(self, score): + return score >= self._cutoff class RepeatedParagraphsFilter(DocumentFilter): - """ - If the document shrinks by > 30% in terms of number of lines after - removing duplicate paragraphs then discard. 
- Source: Gopher (Rae et al., 2021) - """ + """ + If the document shrinks by > 30% in terms of number of lines after + removing duplicate paragraphs then discard. + Source: Gopher (Rae et al., 2021) + """ - def __init__(self, max_repeated_paragraphs_ratio=0.7): - super().__init__() - self._max_repeated_paragraphs_ratio = max_repeated_paragraphs_ratio - self._name = 'repeated_paragraphs' + def __init__(self, max_repeated_paragraphs_ratio=0.7): + super().__init__() + self._max_repeated_paragraphs_ratio = max_repeated_paragraphs_ratio + self._name = "repeated_paragraphs" - def score_document(self, text): - paragraphs = self._paragraphs - if paragraphs is None: - paragraphs = get_paragraphs(text) - return len(set(paragraphs)) / len(paragraphs) + def score_document(self, text): + paragraphs = self._paragraphs + if paragraphs is None: + paragraphs = get_paragraphs(text) + return len(set(paragraphs)) / len(paragraphs) - def keep_document(self, score): - return score >= self._max_repeated_paragraphs_ratio + def keep_document(self, score): + return score >= self._max_repeated_paragraphs_ratio class RepeatedLinesByCharFilter(DocumentFilter): - """ - If the document shrinks by > 20% in terms of number of lines - after removing duplicate lines then discard - Source: Gopher (Rae et al., 2021) - """ + """ + If the document shrinks by > 20% in terms of number of lines + after removing duplicate lines then discard + Source: Gopher (Rae et al., 2021) + """ - def __init__(self, max_repeated_lines_char_ratio=0.8): - super().__init__() - self._cutoff = max_repeated_lines_char_ratio - self._name = 'repeated_lines_char' + def __init__(self, max_repeated_lines_char_ratio=0.8): + super().__init__() + self._cutoff = max_repeated_lines_char_ratio + self._name = "repeated_lines_char" - def score_document(self, text): - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) + def score_document(self, text): + sentences = self._sentences + if sentences is None: + sentences = get_sentences(text) - return len(''.join(set(sentences))) / len(''.join(sentences)) + return len("".join(set(sentences))) / len("".join(sentences)) - def keep_document(self, score): - return score >= self._cutoff + def keep_document(self, score): + return score >= self._cutoff class RepeatedParagraphsByCharFilter(DocumentFilter): - """ - If the document shrinks by > 10% in terms of number of lines after - removing duplicate paragraphs then discard. - Source: Gopher (Rae et al., 2021) - """ + """ + If the document shrinks by > 10% in terms of number of lines after + removing duplicate paragraphs then discard. 
+ Source: Gopher (Rae et al., 2021) + """ - def __init__(self, max_repeated_paragraphs_char_ratio=0.8): - super().__init__() - self._cutoff = max_repeated_paragraphs_char_ratio - self._name = 'repeated_paragraphs_char' + def __init__(self, max_repeated_paragraphs_char_ratio=0.8): + super().__init__() + self._cutoff = max_repeated_paragraphs_char_ratio + self._name = "repeated_paragraphs_char" - def score_document(self, text): - paragraphs = self._paragraphs - if paragraphs is None: - paragraphs = get_paragraphs(text) + def score_document(self, text): + paragraphs = self._paragraphs + if paragraphs is None: + paragraphs = get_paragraphs(text) - return len(''.join(set(paragraphs))) / len(''.join(paragraphs)) + return len("".join(set(paragraphs))) / len("".join(paragraphs)) - def keep_document(self, score): - return score >= self._cutoff + def keep_document(self, score): + return score >= self._cutoff class RepeatingTopNGramsFilter(DocumentFilter): - """ - If the document shrinks by > x% in terms of number of characters after - removing the top n-grams then discard. - Source: Gopher (Rae et al., 2021) - """ - - def __init__(self, n=2, max_repeating_ngram_ratio=0.2, lang='en'): - super().__init__() - self._n = n - self._cutoff = max_repeating_ngram_ratio - self._max_ratio = 1.0 - self._word_splitter = get_word_splitter(lang) - self._name = f'repeating_top_{n}grams' - - def score_document(self, text): - ngrams = self._ngrams - if ngrams is None: - split_text = self._word_splitter(text.strip()) - if len(split_text) < self._n: - return self._max_ratio - ngrams = get_ngrams(split_text, self._n) - unique_ngrams = set(ngrams) - # Find the most frequent ngram in the zipped ngram list - counts = { - ngram: { - 'freq': 0, - 'num_chars': sum(len(word) for word in ngram) - } for ngram in unique_ngrams - } - for ngram in ngrams: - counts[ngram]['freq'] += 1 - most_frqnt_ngram = ' '.join(max(counts, key=lambda x: counts[x]['freq'])) - # Find the number of characters the most frequent ngram - # contributes to the document - nchar = len(text) - len_diff = nchar - len(text.replace(most_frqnt_ngram, '')) - if nchar > 0: - score = len_diff / nchar - else: - # Remove if the document is empty - score = 1.0 - return score - - def keep_document(self, score): - return score <= self._cutoff + """ + If the document shrinks by > x% in terms of number of characters after + removing the top n-grams then discard. 
+ Source: Gopher (Rae et al., 2021) + """ + + def __init__(self, n=2, max_repeating_ngram_ratio=0.2, lang="en"): + super().__init__() + self._n = n + self._cutoff = max_repeating_ngram_ratio + self._max_ratio = 1.0 + self._word_splitter = get_word_splitter(lang) + self._name = f"repeating_top_{n}grams" + + def score_document(self, text): + ngrams = self._ngrams + if ngrams is None: + split_text = self._word_splitter(text.strip()) + if len(split_text) < self._n: + return self._max_ratio + ngrams = get_ngrams(split_text, self._n) + unique_ngrams = set(ngrams) + # Find the most frequent ngram in the zipped ngram list + counts = { + ngram: {"freq": 0, "num_chars": sum(len(word) for word in ngram)} + for ngram in unique_ngrams + } + for ngram in ngrams: + counts[ngram]["freq"] += 1 + most_frqnt_ngram = " ".join(max(counts, key=lambda x: counts[x]["freq"])) + # Find the number of characters the most frequent ngram + # contributes to the document + nchar = len(text) + len_diff = nchar - len(text.replace(most_frqnt_ngram, "")) + if nchar > 0: + score = len_diff / nchar + else: + # Remove if the document is empty + score = 1.0 + return score + + def keep_document(self, score): + return score <= self._cutoff class RepeatingDuplicateNGramsFilter(DocumentFilter): - """ - If the document shrinks by > x% in terms of number of characters - after removing all duplicate n-grams then discard. - Source: Gopher (Rae et al., 2021) - """ - - def __init__(self, n=2, max_repeating_duplicate_ngram_ratio=0.2, lang='en'): - super().__init__() - self._n = n - self._cutoff = max_repeating_duplicate_ngram_ratio - self._max_ratio = 1.0 - self._word_splitter = get_word_splitter(lang) - self._name = f'repeating_dup_{n}gram' - - def score_document(self, text): - ngrams = self._ngrams - if ngrams is None: - split_text = self._word_splitter(text.strip()) - if len(split_text) < self._n: - return self._max_ratio - ngrams = get_ngrams(split_text, self._n) - - counts = {} - duplicated_nchar = 0 - overlapping_ngrams = 0 - for ngram in ngrams: - counts[ngram] = counts.get(ngram, 0) + 1 - if counts[ngram] > 1: - # Count the number of characters in this ngram that haven't been counted already - duplicated_ngrams = sum(len(gram) for gram in ngram[overlapping_ngrams:]) - # Count the spaces between the ngrams - nspaces = min(self._n - overlapping_ngrams, self._n - 1) - duplicated_nchar += duplicated_ngrams + nspaces - overlapping_ngrams = self._n - overlapping_ngrams = max(overlapping_ngrams - 1, 0) - - nchar = len(text) - if nchar > 0: - score = duplicated_nchar / nchar - else: - # Remove if the document is empty - score = 1.0 - return score - - def keep_document(self, score): - return score <= self._cutoff + """ + If the document shrinks by > x% in terms of number of characters + after removing all duplicate n-grams then discard. 
+ Source: Gopher (Rae et al., 2021) + """ + + def __init__(self, n=2, max_repeating_duplicate_ngram_ratio=0.2, lang="en"): + super().__init__() + self._n = n + self._cutoff = max_repeating_duplicate_ngram_ratio + self._max_ratio = 1.0 + self._word_splitter = get_word_splitter(lang) + self._name = f"repeating_dup_{n}gram" + + def score_document(self, text): + ngrams = self._ngrams + if ngrams is None: + split_text = self._word_splitter(text.strip()) + if len(split_text) < self._n: + return self._max_ratio + ngrams = get_ngrams(split_text, self._n) + + counts = {} + duplicated_nchar = 0 + overlapping_ngrams = 0 + for ngram in ngrams: + counts[ngram] = counts.get(ngram, 0) + 1 + if counts[ngram] > 1: + # Count the number of characters in this ngram that haven't been counted already + duplicated_ngrams = sum( + len(gram) for gram in ngram[overlapping_ngrams:] + ) + # Count the spaces between the ngrams + nspaces = min(self._n - overlapping_ngrams, self._n - 1) + duplicated_nchar += duplicated_ngrams + nspaces + overlapping_ngrams = self._n + overlapping_ngrams = max(overlapping_ngrams - 1, 0) + + nchar = len(text) + if nchar > 0: + score = duplicated_nchar / nchar + else: + # Remove if the document is empty + score = 1.0 + return score + + def keep_document(self, score): + return score <= self._cutoff class PunctuationFilter(DocumentFilter): - """ - If more than 85% of the sentences do not end with a - punctuation mark then discard. - Source: Google C4 processing - """ - - def __init__(self, max_num_sentences_without_endmark_ratio=0.85): - super().__init__() - self._cutoff = max_num_sentences_without_endmark_ratio - self._name = 'punctuation' - - def score_document(self, text): - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) - num_sentence_without_endmarks = len( - [s for s in sentences if not s.strip().endswith(end_marks)]) - return num_sentence_without_endmarks / len(sentences) - - def keep_document(self, score): - return score <= self._cutoff + """ + If more than 85% of the sentences do not end with a + punctuation mark then discard. + Source: Google C4 processing + """ + + def __init__(self, max_num_sentences_without_endmark_ratio=0.85): + super().__init__() + self._cutoff = max_num_sentences_without_endmark_ratio + self._name = "punctuation" + + def score_document(self, text): + sentences = self._sentences + if sentences is None: + sentences = get_sentences(text) + num_sentence_without_endmarks = len( + [s for s in sentences if not s.strip().endswith(end_marks)] + ) + return num_sentence_without_endmarks / len(sentences) + + def keep_document(self, score): + return score <= self._cutoff class EllipsisFilter(DocumentFilter): - """ - If more than 30% of the sentences end with an elipsis then discard. - Source: Google C4 processing - """ - - def __init__(self, max_num_lines_ending_with_ellipsis_ratio=0.3): - super().__init__() - self._cutoff = max_num_lines_ending_with_ellipsis_ratio - self._name = 'ellipsis' - - def score_document(self, text): - sentences = self._sentences - if sentences is None: - sentences = get_sentences(text) - num_lines_ending_with_ellipsis = 0 - for sentence in sentences: - for ellipsis in ellipsis_marks: - if sentence.strip().lower().endswith(ellipsis): - num_lines_ending_with_ellipsis += 1 - break - return num_lines_ending_with_ellipsis / len(sentences) - - def keep_document(self, score): - return score <= self._cutoff + """ + If more than 30% of the sentences end with an elipsis then discard. 
+ Source: Google C4 processing + """ + + def __init__(self, max_num_lines_ending_with_ellipsis_ratio=0.3): + super().__init__() + self._cutoff = max_num_lines_ending_with_ellipsis_ratio + self._name = "ellipsis" + + def score_document(self, text): + sentences = self._sentences + if sentences is None: + sentences = get_sentences(text) + num_lines_ending_with_ellipsis = 0 + for sentence in sentences: + for ellipsis in ellipsis_marks: + if sentence.strip().lower().endswith(ellipsis): + num_lines_ending_with_ellipsis += 1 + break + return num_lines_ending_with_ellipsis / len(sentences) + + def keep_document(self, score): + return score <= self._cutoff class CommonEnglishWordsFilter(DocumentFilter): - """ - If the sentence contains at least 2 common english words, keep - NOTE: we purposefully check for the lowercase versions of those common words - to remove documents with over-capitalization. - """ - - def __init__(self, min_num_common_words=2, stop_at_false=True): - super().__init__() - self._cutoff = min_num_common_words - self._stop_at_false = stop_at_false - self._word_splitter = get_word_splitter('en') - self._name = 'common_english_words' - - def score_document(self, text): - common_word_counter = 0 - for word in self._word_splitter(text.strip()): - if word in common_english_words: - common_word_counter += 1 - if self._stop_at_false and common_word_counter >= self._cutoff: - return common_word_counter + """ + If the sentence contains at least 2 common english words, keep + NOTE: we purposefully check for the lowercase versions of those common words + to remove documents with over-capitalization. + """ + + def __init__(self, min_num_common_words=2, stop_at_false=True): + super().__init__() + self._cutoff = min_num_common_words + self._stop_at_false = stop_at_false + self._word_splitter = get_word_splitter("en") + self._name = "common_english_words" + + def score_document(self, text): + common_word_counter = 0 + for word in self._word_splitter(text.strip()): + if word in common_english_words: + common_word_counter += 1 + if self._stop_at_false and common_word_counter >= self._cutoff: + return common_word_counter - return common_word_counter + return common_word_counter - def keep_document(self, score): - return score >= self._cutoff + def keep_document(self, score): + return score >= self._cutoff class WordsWithoutAlphabetsFilter(DocumentFilter): - """ - 80% of words in a document must contain at least one alphabetic character - Source: Gopher (Rae et al., 2021) - """ + """ + 80% of words in a document must contain at least one alphabetic character + Source: Gopher (Rae et al., 2021) + """ - def __init__(self, min_words_with_alphabets=0.8, lang='en'): - super().__init__() - self._cutoff = min_words_with_alphabets - self._word_splitter = get_word_splitter(lang) - self._name = 'words_without_alphabets' + def __init__(self, min_words_with_alphabets=0.8, lang="en"): + super().__init__() + self._cutoff = min_words_with_alphabets + self._word_splitter = get_word_splitter(lang) + self._name = "words_without_alphabets" - def score_document(self, text): - num_english_alpha = 0 - words = self._word_splitter(text.strip()) - for word in words: - if regex_alpha.search(word): - num_english_alpha += 1 + def score_document(self, text): + num_english_alpha = 0 + words = self._word_splitter(text.strip()) + for word in words: + if regex_alpha.search(word): + num_english_alpha += 1 - return num_english_alpha / len(words) + return num_english_alpha / len(words) - def keep_document(self, score): - return score 
>= self._cutoff + def keep_document(self, score): + return score >= self._cutoff class PornographicUrlsFilter(DocumentFilter): - """ - Check if any of the urls within the document point to porn - """ + """ + Check if any of the urls within the document point to porn + """ - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - def score_document(self, text): - all_urls = regex_url.findall(text) - for url in all_urls: - if 'porn' in url: - return 1 + def score_document(self, text): + all_urls = regex_url.findall(text) + for url in all_urls: + if "porn" in url: + return 1 - return 0 + return 0 - def keep_document(self, score): - return score != 1 + def keep_document(self, score): + return score != 1 diff --git a/nemo_curator/gpu_deduplication/connected_component.py b/nemo_curator/gpu_deduplication/connected_component.py index d2d8d464..211fb451 100644 --- a/nemo_curator/gpu_deduplication/connected_component.py +++ b/nemo_curator/gpu_deduplication/connected_component.py @@ -26,7 +26,8 @@ from dask.utils import M from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import ( - convert_str_pair_adlr_ids_to_int,) + convert_str_pair_adlr_ids_to_int, +) from nemo_curator.gpu_deduplication.utils import ( enable_spilling, get_client, @@ -37,262 +38,253 @@ def sort_adlr_id(df): - x = df[["adlr_id_x", "adlr_id_y"]].values - x = cupy.sort(x, axis=1) - df["adlr_id_x"] = x[:, 0] - df["adlr_id_y"] = x[:, 1] - for i in ["adlr_id_x", "adlr_id_y"]: - df[i] = df[i].astype("uint64") - return df + x = df[["adlr_id_x", "adlr_id_y"]].values + x = cupy.sort(x, axis=1) + df["adlr_id_x"] = x[:, 0] + df["adlr_id_y"] = x[:, 1] + for i in ["adlr_id_x", "adlr_id_y"]: + df[i] = df[i].astype("uint64") + return df def thresholding(df, threshold=0.8): - mask = df.jaccard > threshold - df.loc[mask, "jaccard"] = np.int8(1) - df.loc[~mask, "jaccard"] = np.int8(0) - return df + mask = df.jaccard > threshold + df.loc[mask, "jaccard"] = np.int8(1) + df.loc[~mask, "jaccard"] = np.int8(0) + return df @timer def run_connected_components(jaccard_pairs_path, adlr_id_path, output_path): - Comms.initialize(p2p=True) - df = dask_cudf.read_parquet(jaccard_pairs_path, - blocksize="1GB", - aggregate_files=True) - df = df[df["jaccard"] == 1].reset_index(drop=True) - - labels_df = dask_cudf.read_parquet(adlr_id_path) - num_nodes = len(labels_df) - - self_edge_df = labels_df[["uid"]].rename(columns={"uid": "adlr_id_x"}) - self_edge_df["adlr_id_y"] = self_edge_df["adlr_id_x"] - - df = df[["adlr_id_x", "adlr_id_y"]].astype(np.int64) - df = dask_cudf.concat([df, self_edge_df]) - - G = cugraph.MultiGraph(directed=False) - G.from_dask_cudf_edgelist(df, - source="adlr_id_x", - destination="adlr_id_y", - renumber=False) - result = dcg.weakly_connected_components(G) - del G - max_partitions = min(32, result.npartitions) - n_components = len(result[["labels" - ]].drop_duplicates(split_out=max_partitions)) - num_labels = len(result) - print("# of groups", n_components) - print("# of docs removed", num_labels - n_components) - labels_df = labels_df.merge(result, - left_on=["uid"], - right_on=["vertex"], - how="inner") - labels_df = labels_df[["dataset_id", "doc_id", "labels"]] - labels_df = labels_df.rename(columns={"labels": "group"}) - labels_df = labels_df.persist() - # Doing an inner merge above - # should not change any rows - - assert num_nodes == len(labels_df) - print(f"assert num_nodes:{num_nodes}==labels_df:{len(labels_df)} passed") - labels_df.to_parquet(output_path, write_index=False) - 
Comms.destroy() + Comms.initialize(p2p=True) + df = dask_cudf.read_parquet( + jaccard_pairs_path, blocksize="1GB", aggregate_files=True + ) + df = df[df["jaccard"] == 1].reset_index(drop=True) + + labels_df = dask_cudf.read_parquet(adlr_id_path) + num_nodes = len(labels_df) + + self_edge_df = labels_df[["uid"]].rename(columns={"uid": "adlr_id_x"}) + self_edge_df["adlr_id_y"] = self_edge_df["adlr_id_x"] + + df = df[["adlr_id_x", "adlr_id_y"]].astype(np.int64) + df = dask_cudf.concat([df, self_edge_df]) + + G = cugraph.MultiGraph(directed=False) + G.from_dask_cudf_edgelist( + df, source="adlr_id_x", destination="adlr_id_y", renumber=False + ) + result = dcg.weakly_connected_components(G) + del G + max_partitions = min(32, result.npartitions) + n_components = len(result[["labels"]].drop_duplicates(split_out=max_partitions)) + num_labels = len(result) + print("# of groups", n_components) + print("# of docs removed", num_labels - n_components) + labels_df = labels_df.merge( + result, left_on=["uid"], right_on=["vertex"], how="inner" + ) + labels_df = labels_df[["dataset_id", "doc_id", "labels"]] + labels_df = labels_df.rename(columns={"labels": "group"}) + labels_df = labels_df.persist() + # Doing an inner merge above + # should not change any rows + + assert num_nodes == len(labels_df) + print(f"assert num_nodes:{num_nodes}==labels_df:{len(labels_df)} passed") + labels_df.to_parquet(output_path, write_index=False) + Comms.destroy() def attach_args(parser=None): - description = """Computes connected component""" - if not parser: - parser = parse_nc_args(description=description) - - parser.add_argument( - "--jaccard-pairs-path", - type=str, - help="The directory containing the jaccard results", - ) - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results to", - ) - parser.add_argument( - "--cache-dir", - type=str, - help="The cache directory to write intermediate results to", - ) - return parser + description = """Computes connected component""" + if not parser: + parser = parse_nc_args(description=description) + + parser.add_argument( + "--jaccard-pairs-path", + type=str, + help="The directory containing the jaccard results", + ) + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results to", + ) + parser.add_argument( + "--cache-dir", + type=str, + help="The cache directory to write intermediate results to", + ) + return parser def delete_cache_data(path): - if "cache" not in path: - return - cmd = f"rm -rf {path}" - print(cmd) - os.system(cmd) + if "cache" not in path: + return + cmd = f"rm -rf {path}" + print(cmd) + os.system(cmd) def write_output(ddf, output_path): - if not isinstance(output_path, str): - assert TypeError(f"output_path should be str. got {type(output_path)}") - print(f"write {output_path} ...") - ddf.to_parquet(output_path, write_index=False) + if not isinstance(output_path, str): + assert TypeError(f"output_path should be str. 
got {type(output_path)}") + print(f"write {output_path} ...") + ddf.to_parquet(output_path, write_index=False) def get_unique_ids_per_partition(df): - unique_df_ls = [] - for tag in ["x", "y"]: - subset_df = df[[f"dataset_id_{tag}", f"doc_id_{tag}"]].drop_duplicates() - subset_df = subset_df.rename(columns={ - f"dataset_id_{tag}": "dataset_id", - f"doc_id_{tag}": "doc_id" - }) - unique_df_ls.append(subset_df) - unique_df = cudf.concat(unique_df_ls, ignore_index=True) - unique_df = unique_df.drop_duplicates() - return unique_df + unique_df_ls = [] + for tag in ["x", "y"]: + subset_df = df[[f"dataset_id_{tag}", f"doc_id_{tag}"]].drop_duplicates() + subset_df = subset_df.rename( + columns={f"dataset_id_{tag}": "dataset_id", f"doc_id_{tag}": "doc_id"} + ) + unique_df_ls.append(subset_df) + unique_df = cudf.concat(unique_df_ls, ignore_index=True) + unique_df = unique_df.drop_duplicates() + return unique_df @timer def write_dedup_parsed_adlr_id(args): - dedup_parsed_adlr_id_path = f"{args.cache_dir}/dedup_parsed_adlr_id.parquet" - ddf = dask_cudf.read_parquet( - args.jaccard_pairs_path, - columns=["adlr_id_x", "adlr_id_y"], - blocksize="1GB", - aggregate_files=True, - ) - ddf = ddf.map_partitions( - convert_str_pair_adlr_ids_to_int, - meta={ - "dataset_id_x": "uint32", - "doc_id_x": "int64", - "dataset_id_y": "uint32", - "doc_id_y": "int64", - }, - ) - - unique_docs = ddf.map_partitions(get_unique_ids_per_partition) - unique_docs = unique_docs.drop_duplicates(split_out=ddf.npartitions // 4) - unique_docs["uid"] = np.uint64(1) - unique_docs["uid"] = unique_docs["uid"].cumsum() - unique_docs["uid"] = unique_docs["uid"] - 1 - write_output(unique_docs, dedup_parsed_adlr_id_path) - return dedup_parsed_adlr_id_path + dedup_parsed_adlr_id_path = f"{args.cache_dir}/dedup_parsed_adlr_id.parquet" + ddf = dask_cudf.read_parquet( + args.jaccard_pairs_path, + columns=["adlr_id_x", "adlr_id_y"], + blocksize="1GB", + aggregate_files=True, + ) + ddf = ddf.map_partitions( + convert_str_pair_adlr_ids_to_int, + meta={ + "dataset_id_x": "uint32", + "doc_id_x": "int64", + "dataset_id_y": "uint32", + "doc_id_y": "int64", + }, + ) + + unique_docs = ddf.map_partitions(get_unique_ids_per_partition) + unique_docs = unique_docs.drop_duplicates(split_out=ddf.npartitions // 4) + unique_docs["uid"] = np.uint64(1) + unique_docs["uid"] = unique_docs["uid"].cumsum() + unique_docs["uid"] = unique_docs["uid"] - 1 + write_output(unique_docs, dedup_parsed_adlr_id_path) + return dedup_parsed_adlr_id_path def batched_merge_and_write(ddf, ddf_adlr_id, output_path, batch_size=32): - total_batches = (ddf.npartitions + batch_size - 1) // batch_size - for batch_id, offset in enumerate(range(0, ddf.npartitions, batch_size)): - st = time() - subset_ddf = ddf.partitions[offset:offset + batch_size] - for tag in ["x", "y"]: - subset_ddf = subset_ddf.merge( - ddf_adlr_id, - left_on=[f"dataset_id_{tag}", f"doc_id_{tag}"], - right_on=["dataset_id", "doc_id"], - how="inner", - broadcast=True, - ) - subset_ddf = subset_ddf.rename(columns={"uid": f"adlr_id_{tag}"}) - subset_ddf = subset_ddf.drop( - columns=[f"dataset_id_{tag}", f"doc_id_{tag}"]) - - subset_ddf = subset_ddf[["adlr_id_x", "adlr_id_y", "jaccard"]] - output_batch_path = os.path.join(output_path, f"{batch_id}.parquet") - subset_ddf.to_parquet(output_batch_path, write_index=False) - - et = time() - print(f"batch_id = {batch_id}/{total_batches}, time = {et - st}", - flush=True) + total_batches = (ddf.npartitions + batch_size - 1) // batch_size + for batch_id, offset in enumerate(range(0, 
ddf.npartitions, batch_size)): + st = time() + subset_ddf = ddf.partitions[offset : offset + batch_size] + for tag in ["x", "y"]: + subset_ddf = subset_ddf.merge( + ddf_adlr_id, + left_on=[f"dataset_id_{tag}", f"doc_id_{tag}"], + right_on=["dataset_id", "doc_id"], + how="inner", + broadcast=True, + ) + subset_ddf = subset_ddf.rename(columns={"uid": f"adlr_id_{tag}"}) + subset_ddf = subset_ddf.drop(columns=[f"dataset_id_{tag}", f"doc_id_{tag}"]) + + subset_ddf = subset_ddf[["adlr_id_x", "adlr_id_y", "jaccard"]] + output_batch_path = os.path.join(output_path, f"{batch_id}.parquet") + subset_ddf.to_parquet(output_batch_path, write_index=False) + + et = time() + print(f"batch_id = {batch_id}/{total_batches}, time = {et - st}", flush=True) @timer def write_encoded_jaccard_pair(args, client): - dedup_parsed_adlr_id_path = f"{args.cache_dir}/dedup_parsed_adlr_id.parquet" - output_path = f"{args.cache_dir}/encoded_jaccard_pair/" - ddf_adlr_id = dask_cudf.read_parquet(dedup_parsed_adlr_id_path, - blocksize="2GB", - aggregate_files=True) - ddf_adlr_id = ddf_adlr_id.persist() - len(ddf_adlr_id) - ddf = dask_cudf.read_parquet( - args.jaccard_pairs_path, - blocksize="256MB", - aggregate_files=True, - ) - ddf = ddf.map_partitions( - convert_str_pair_adlr_ids_to_int, - meta={ - "jaccard": "float32", - "dataset_id_x": "uint32", - "doc_id_x": "int64", - "dataset_id_y": "uint32", - "doc_id_y": "int64", - }, - ) - num_workers = get_num_workers(client) - batched_merge_and_write(ddf, ddf_adlr_id, output_path, num_workers) + dedup_parsed_adlr_id_path = f"{args.cache_dir}/dedup_parsed_adlr_id.parquet" + output_path = f"{args.cache_dir}/encoded_jaccard_pair/" + ddf_adlr_id = dask_cudf.read_parquet( + dedup_parsed_adlr_id_path, blocksize="2GB", aggregate_files=True + ) + ddf_adlr_id = ddf_adlr_id.persist() + len(ddf_adlr_id) + ddf = dask_cudf.read_parquet( + args.jaccard_pairs_path, + blocksize="256MB", + aggregate_files=True, + ) + ddf = ddf.map_partitions( + convert_str_pair_adlr_ids_to_int, + meta={ + "jaccard": "float32", + "dataset_id_x": "uint32", + "doc_id_x": "int64", + "dataset_id_y": "uint32", + "doc_id_y": "int64", + }, + ) + num_workers = get_num_workers(client) + batched_merge_and_write(ddf, ddf_adlr_id, output_path, num_workers) @timer def write_dedup_encoded_jaccard_pair(args, client): - input_path = f"{args.cache_dir}/encoded_jaccard_pair" - output_path = f"{args.cache_dir}/final_dedup_encoded_jaccard_pair.parquet" - - ddf = dask_cudf.read_parquet(input_path, - blocksize="512MB", - aggregate_files=True) - meta = {"adlr_id_x": "uint64", "adlr_id_y": "uint64", "jaccard": "float32"} - ddf = ddf.map_partitions(sort_adlr_id, meta=meta) - ddf = ddf.map_partitions(thresholding, meta=meta) - ddf = ddf.map_partitions( - M.drop_duplicates, - meta=ddf._meta, - enforce_metadata=False, - transform_divisions=False, - align_dataframes=False, - ) - ddf = dd_shuffle( - ddf, - ["adlr_id_x", "doc_id"], - ignore_index=True, - shuffle="tasks", - ) - ddf = ddf.map_partitions( - M.drop_duplicates, - meta=ddf._meta, - enforce_metadata=False, - transform_divisions=False, - align_dataframes=False, - ) - - write_output(ddf, output_path) - return output_path + input_path = f"{args.cache_dir}/encoded_jaccard_pair" + output_path = f"{args.cache_dir}/final_dedup_encoded_jaccard_pair.parquet" + + ddf = dask_cudf.read_parquet(input_path, blocksize="512MB", aggregate_files=True) + meta = {"adlr_id_x": "uint64", "adlr_id_y": "uint64", "jaccard": "float32"} + ddf = ddf.map_partitions(sort_adlr_id, meta=meta) + ddf = 
ddf.map_partitions(thresholding, meta=meta) + ddf = ddf.map_partitions( + M.drop_duplicates, + meta=ddf._meta, + enforce_metadata=False, + transform_divisions=False, + align_dataframes=False, + ) + ddf = dd_shuffle( + ddf, + ["adlr_id_x", "doc_id"], + ignore_index=True, + shuffle="tasks", + ) + ddf = ddf.map_partitions( + M.drop_duplicates, + meta=ddf._meta, + enforce_metadata=False, + transform_divisions=False, + align_dataframes=False, + ) + + write_output(ddf, output_path) + return output_path def main(args): - description = """Takes a dataset consisting of document pairs + description = """Takes a dataset consisting of document pairs and their corresponding jaccard similarity to compute connected components of docuements across pairs to find similar docuemnt after applying a given threshold. The result is a dataset consisting of all documents that are similar (above the threshold) and the component they belong to.""" - start = time() - output_path = os.path.join(args.output_dir, "connected_components.parquet") + start = time() + output_path = os.path.join(args.output_dir, "connected_components.parquet") - client = get_client(args) - enable_spilling() - client.run(enable_spilling) - adlr_id_path = write_dedup_parsed_adlr_id(args) - write_encoded_jaccard_pair(args, client) - jaccard_pairs_path = write_dedup_encoded_jaccard_pair(args, client) - run_connected_components(jaccard_pairs_path, adlr_id_path, output_path) - print(f"All done in {time()-start:.1f} seconds") + client = get_client(args) + enable_spilling() + client.run(enable_spilling) + adlr_id_path = write_dedup_parsed_adlr_id(args) + write_encoded_jaccard_pair(args, client) + jaccard_pairs_path = write_dedup_encoded_jaccard_pair(args, client) + run_connected_components(jaccard_pairs_path, adlr_id_path, output_path) + print(f"All done in {time()-start:.1f} seconds") def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) - \ No newline at end of file + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/ioutils.py b/nemo_curator/gpu_deduplication/ioutils.py index 88c18a17..7ac253c0 100644 --- a/nemo_curator/gpu_deduplication/ioutils.py +++ b/nemo_curator/gpu_deduplication/ioutils.py @@ -20,55 +20,52 @@ from dask import dataframe as dd from tqdm import tqdm + # TODO: # Combine this with # nemo_curator.distributed_utils.read_cudf_jsonl -def read_json_func(files, - engine="cudf", - include_path_column=False, - columns=None): - """ +def read_json_func(files, engine="cudf", include_path_column=False, columns=None): + """ Reads multiple Json Lines files into a cuDF dataframe with an additional `path` column denoting the path of the input file. 
""" - if not include_path_column: - if columns: - return cudf.read_json(files, engine="cudf", lines=True)[columns] - else: - return cudf.read_json(files, engine="cudf", lines=True) + if not include_path_column: + if columns: + return cudf.read_json(files, engine="cudf", lines=True)[columns] + else: + return cudf.read_json(files, engine="cudf", lines=True) - dfs = [] - for file in files: - if columns: - df = cudf.read_json(file, engine=engine, lines=True)[columns] - else: - df = cudf.read_json(file, engine=engine, lines=True) - df["path"] = file - dfs.append(df) - return cudf.concat(dfs, ignore_index=True) + dfs = [] + for file in files: + if columns: + df = cudf.read_json(file, engine=engine, lines=True)[columns] + else: + df = cudf.read_json(file, engine=engine, lines=True) + df["path"] = file + dfs.append(df) + return cudf.concat(dfs, ignore_index=True) def bucketed_read(files, func=read_json_func, b_size=2, meta=None, **kwargs): - """ + """ Read files with `b_size` number of files per bucket. Users can specify their own read """ - filepaths = [ - files[i:i + b_size] for i in range(0, len(files), b_size) # noqa: E203 - ] - if meta: - return dd.from_map(func, filepaths, meta=meta, **kwargs) - else: - return dd.from_map(func, filepaths, **kwargs) + filepaths = [ + files[i : i + b_size] for i in range(0, len(files), b_size) # noqa: E203 + ] + if meta: + return dd.from_map(func, filepaths, meta=meta, **kwargs) + else: + return dd.from_map(func, filepaths, **kwargs) -#TODO: Remove this function +# TODO: Remove this function def regular_read_json(files, include_path_column=False): - return dask_cudf.read_json(files, - engine="cudf", - lines=True, - include_path_column=include_path_column) + return dask_cudf.read_json( + files, engine="cudf", lines=True, include_path_column=include_path_column + ) def batched_writing( @@ -77,7 +74,7 @@ def batched_writing( partition_on: Sequence[str], parts_ber_batch: int = 32, ): - """ + """ Write a dask dataframe to parquet in batches. 
This allows us to do batched exectution and prevent OOMs Args: @@ -87,32 +84,33 @@ def batched_writing( parts_ber_batch: number of partitions per batch """ - total_partitions = dask_df.npartitions - for batch_id, part_offset in tqdm( - enumerate(range(0, dask_df.npartitions, parts_ber_batch))): - print(f"\nStarted processing batch in = {batch_id}", flush=True) - df = dask_df.partitions[part_offset:part_offset + parts_ber_batch] - if partition_on: - df.to_parquet( - output_path, - partition_on=partition_on, - name_function=lambda x: f"batch_{batch_id}_part_{x}.parquet", - write_metadata_file=False, - ) - else: - df.to_parquet( - output_path, - name_function=lambda x: f"batch_{batch_id}_part_{x}.parquet", - write_metadata_file=False, - ) - print( - f"Part {part_offset+parts_ber_batch}/{total_partitions} completed", - flush=True, - ) + total_partitions = dask_df.npartitions + for batch_id, part_offset in tqdm( + enumerate(range(0, dask_df.npartitions, parts_ber_batch)) + ): + print(f"\nStarted processing batch in = {batch_id}", flush=True) + df = dask_df.partitions[part_offset : part_offset + parts_ber_batch] + if partition_on: + df.to_parquet( + output_path, + partition_on=partition_on, + name_function=lambda x: f"batch_{batch_id}_part_{x}.parquet", + write_metadata_file=False, + ) + else: + df.to_parquet( + output_path, + name_function=lambda x: f"batch_{batch_id}_part_{x}.parquet", + write_metadata_file=False, + ) + print( + f"Part {part_offset+parts_ber_batch}/{total_partitions} completed", + flush=True, + ) def strip_trailing_sep(path: str): - """ + """ Strips a path string of trailing path seperators like `/` if any. """ - return path.rstrip(os.path.sep) + return path.rstrip(os.path.sep) diff --git a/nemo_curator/gpu_deduplication/jaccard_compute.py b/nemo_curator/gpu_deduplication/jaccard_compute.py index 674df14c..f90e6c44 100644 --- a/nemo_curator/gpu_deduplication/jaccard_compute.py +++ b/nemo_curator/gpu_deduplication/jaccard_compute.py @@ -20,7 +20,8 @@ import numpy as np from nemo_curator.gpu_deduplication.jaccard_utils.jaccard_similarity_utils import ( - compute_jaccard_and_create_pair_df,) + compute_jaccard_and_create_pair_df, +) from nemo_curator.gpu_deduplication.utils import ( enable_spilling, get_client, @@ -30,121 +31,124 @@ def create_bins(path_dicts, max_size): - path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True) - bins, bin_sizes = [], [] - for path_d in path_dicts: - new_path, new_size = path_d["path"], path_d["str_bytes"] - for i, bin_size in enumerate(bin_sizes): - if bin_size + new_size <= max_size: - bins[i].append(new_path) - bin_sizes[i] += new_size - new_size = 0 - break - if new_size: - bins.append([new_path]) - bin_sizes.append(new_size) - return bins + path_dicts.sort(key=lambda x: x["str_bytes"], reverse=True) + bins, bin_sizes = [], [] + for path_d in path_dicts: + new_path, new_size = path_d["path"], path_d["str_bytes"] + for i, bin_size in enumerate(bin_sizes): + if bin_size + new_size <= max_size: + bins[i].append(new_path) + bin_sizes[i] += new_size + new_size = 0 + break + if new_size: + bins.append([new_path]) + bin_sizes.append(new_size) + return bins def get_anchor_docs_and_string_size(path): - df = cudf.read_parquet(path) - str_bytes = df["text"].str.byte_count().sum() - is_anchor_flag = (df["adlr_id"] == df["anchor_1_adlr_id"]) | ( - df["adlr_id"] == df["anchor_0_adlr_id"]) - anchor_df = df[is_anchor_flag].reset_index(drop=True) - return anchor_df, {"path": path, "str_bytes": str_bytes} + df = cudf.read_parquet(path) + str_bytes = 
df["text"].str.byte_count().sum() + is_anchor_flag = (df["adlr_id"] == df["anchor_1_adlr_id"]) | ( + df["adlr_id"] == df["anchor_0_adlr_id"] + ) + anchor_df = df[is_anchor_flag].reset_index(drop=True) + return anchor_df, {"path": path, "str_bytes": str_bytes} def compute_jaccard_on_1_partition(path): - try: - df = cudf.read_parquet(path) - pair_df = compute_jaccard_and_create_pair_df(df) - except OverflowError: - paths = [entry.path for entry in os.scandir(os.path.join(path))] - anchor_df_str_size_ls = [ - get_anchor_docs_and_string_size(path) for path in paths - ] - anchor_df = cudf.concat( - [anchor_doc for anchor_doc, _ in anchor_df_str_size_ls], - ignore_index=True).drop_duplicates() - df_str_size = [str_size for _, str_size in anchor_df_str_size_ls] - paths = create_bins(df_str_size, np.iinfo(np.int32).max // 10) - pair_dfs = [] - for path in paths: - print(path) - df = cudf.read_parquet(path).reset_index(drop=True) - df = cudf.concat([df, anchor_df], ignore_index=True) - pair_df = compute_jaccard_and_create_pair_df(df) - pair_dfs.append(pair_df) - pair_df = cudf.concat(pair_dfs, ignore_index=True) - return pair_df + try: + df = cudf.read_parquet(path) + pair_df = compute_jaccard_and_create_pair_df(df) + except OverflowError: + paths = [entry.path for entry in os.scandir(os.path.join(path))] + anchor_df_str_size_ls = [ + get_anchor_docs_and_string_size(path) for path in paths + ] + anchor_df = cudf.concat( + [anchor_doc for anchor_doc, _ in anchor_df_str_size_ls], ignore_index=True + ).drop_duplicates() + df_str_size = [str_size for _, str_size in anchor_df_str_size_ls] + paths = create_bins(df_str_size, np.iinfo(np.int32).max // 10) + pair_dfs = [] + for path in paths: + print(path) + df = cudf.read_parquet(path).reset_index(drop=True) + df = cudf.concat([df, anchor_df], ignore_index=True) + pair_df = compute_jaccard_and_create_pair_df(df) + pair_dfs.append(pair_df) + pair_df = cudf.concat(pair_dfs, ignore_index=True) + return pair_df def run_jaccard_compute(shuffled_docs_path, output_final_results_path): - print("Starting Jaccard Computation", flush=True) - st = time.time() - paths = [ - entry.path - for entry in os.scandir(shuffled_docs_path) - if not entry.path.endswith(".txt") - ] - meta_df = cudf.DataFrame({ - "adlr_id_x": ["x"], - "adlr_id_y": ["y"], - "jaccard": np.float32([0.0]), - }) - result_df = dd.from_map(compute_jaccard_on_1_partition, paths, - meta=meta_df).reset_index(drop=True) - - result_df.to_parquet( - output_final_results_path, - write_index=False, - write_metadata_file=False, - ) - print(f"Jaccard Computing+Writing time: {time.time() - st:.1f} seconds") + print("Starting Jaccard Computation", flush=True) + st = time.time() + paths = [ + entry.path + for entry in os.scandir(shuffled_docs_path) + if not entry.path.endswith(".txt") + ] + meta_df = cudf.DataFrame( + { + "adlr_id_x": ["x"], + "adlr_id_y": ["y"], + "jaccard": np.float32([0.0]), + } + ) + result_df = dd.from_map( + compute_jaccard_on_1_partition, paths, meta=meta_df + ).reset_index(drop=True) + + result_df.to_parquet( + output_final_results_path, + write_index=False, + write_metadata_file=False, + ) + print(f"Jaccard Computing+Writing time: {time.time() - st:.1f} seconds") def main(args): - description = """Computes the Jaccard similarity between document pairs + description = """Computes the Jaccard similarity between document pairs from partitioned parquet dataset. Result is a parquet dataset consiting of document id pair along with their Jaccard similarity score. 
""" - OUTPUT_PATH = args.output_dir - shuffled_docs_path = args.shuffled_docs_path - output_final_results_path = os.path.join(OUTPUT_PATH, - "dedup_final_results.parquet") - client = get_client(args) - enable_spilling() - client.run(enable_spilling) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard compute script", flush=True) + OUTPUT_PATH = args.output_dir + shuffled_docs_path = args.shuffled_docs_path + output_final_results_path = os.path.join(OUTPUT_PATH, "dedup_final_results.parquet") + client = get_client(args) + enable_spilling() + client.run(enable_spilling) + print(f"Num Workers = {get_num_workers(client)}", flush=True) + print("Connected to dask cluster", flush=True) + print("Running jaccard compute script", flush=True) - # Run actual computation - run_jaccard_compute(shuffled_docs_path, output_final_results_path) + # Run actual computation + run_jaccard_compute(shuffled_docs_path, output_final_results_path) def attach_args(parser=None): - description = """Computes jaccard similarity""" - if not parser: - parser = parse_nc_args(description=description) - - parser.add_argument( - "--shuffled-docs-path", - type=str, - help="The directory containing the shuffled documents", - ) - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results to", - ) - return parser + description = """Computes jaccard similarity""" + if not parser: + parser = parse_nc_args(description=description) + + parser.add_argument( + "--shuffled-docs-path", + type=str, + help="The directory containing the shuffled documents", + ) + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results to", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/jaccard_map_buckets.py b/nemo_curator/gpu_deduplication/jaccard_map_buckets.py index 092d9225..aa60787d 100644 --- a/nemo_curator/gpu_deduplication/jaccard_map_buckets.py +++ b/nemo_curator/gpu_deduplication/jaccard_map_buckets.py @@ -19,9 +19,11 @@ from dask.utils import M from nemo_curator.gpu_deduplication.jaccard_utils.get_anchor_utils import ( - add_anchor_docs,) + add_anchor_docs, +) from nemo_curator.gpu_deduplication.jaccard_utils.get_output_map_utils import ( - get_output_map_based_on_str_bytes,) + get_output_map_based_on_str_bytes, +) from nemo_curator.gpu_deduplication.jaccard_utils.io_utils import ( get_bucket_ddf_from_parquet_path, get_text_ddf_from_json_path_with_blocksize, @@ -41,7 +43,7 @@ def get_anchor_and_output_map_info( num_workers, shuffle_type, ): - """ + """ Get anchor docs with bucket info Args: input_data_paths: list of paths to input data @@ -53,75 +55,76 @@ def get_anchor_and_output_map_info( Returns: ddf_anchor_docs_with_bk """ - ddf_text = get_text_ddf_from_json_path_with_blocksize( - input_data_paths=input_data_paths, - num_files=num_files, - blocksize=text_ddf_blocksize, - ) - ddf_bk = get_bucket_ddf_from_parquet_path(input_bucket_path=input_bucket_path, - num_workers=num_workers) - output_map_df = get_output_map_based_on_str_bytes(ddf_bk=ddf_bk, - ddf_text=ddf_text) - ddf_anchor_docs_with_bk = ddf_bk.map_partitions(add_anchor_docs) - print("output_map_df is based on string bytes", flush=True) - ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge(output_map_df, - on=["bucket"]) 
- # Bucket is no longer needed - ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(columns=["bucket"]) - # Below removes any duplicates lying around after dropping buckets - ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions( - M.drop_duplicates, - meta=ddf_anchor_docs_with_bk._meta, - enforce_metadata=False, - transform_divisions=False, - align_dataframes=False, - ) - ddf_anchor_docs_with_bk = dd_shuffle( - ddf_anchor_docs_with_bk, - ["dataset_id", "doc_id"], - ignore_index=True, - shuffle=shuffle_type, - ).map_partitions( - M.drop_duplicates, - meta=ddf_anchor_docs_with_bk._meta, - enforce_metadata=False, - transform_divisions=False, - align_dataframes=False, - ) - del output_map_df - return ddf_anchor_docs_with_bk + ddf_text = get_text_ddf_from_json_path_with_blocksize( + input_data_paths=input_data_paths, + num_files=num_files, + blocksize=text_ddf_blocksize, + ) + ddf_bk = get_bucket_ddf_from_parquet_path( + input_bucket_path=input_bucket_path, num_workers=num_workers + ) + output_map_df = get_output_map_based_on_str_bytes(ddf_bk=ddf_bk, ddf_text=ddf_text) + ddf_anchor_docs_with_bk = ddf_bk.map_partitions(add_anchor_docs) + print("output_map_df is based on string bytes", flush=True) + ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge( + output_map_df, on=["bucket"] + ) + # Bucket is no longer needed + ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(columns=["bucket"]) + # Below removes any duplicates lying around after dropping buckets + ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions( + M.drop_duplicates, + meta=ddf_anchor_docs_with_bk._meta, + enforce_metadata=False, + transform_divisions=False, + align_dataframes=False, + ) + ddf_anchor_docs_with_bk = dd_shuffle( + ddf_anchor_docs_with_bk, + ["dataset_id", "doc_id"], + ignore_index=True, + shuffle=shuffle_type, + ).map_partitions( + M.drop_duplicates, + meta=ddf_anchor_docs_with_bk._meta, + enforce_metadata=False, + transform_divisions=False, + align_dataframes=False, + ) + del output_map_df + return ddf_anchor_docs_with_bk def attach_args(parser=None): - description = """Takes the buckets generated from minhashes and uses + description = """Takes the buckets generated from minhashes and uses document length information to create a coarse mapping of mapping multiple buckets to a logical partition by using a modified bin packing algorithm. 
""" - if not parser: - parser = parse_nc_args(description=description) - parser.add_argument( - "--input-bucket-dir", - type=str, - help="The directory containing bucket information files", - ) - parser.add_argument( - "--text-ddf-blocksize", - type=int, - default=256, - help="The block size for chunking jsonl files for text ddf in mb", - ) - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results in", - ) - parser.add_argument( - "--shuffle-type", - type=str, - default="tasks", - help="Type of shuffle to use before writing to parquet", - ) - return parser + if not parser: + parser = parse_nc_args(description=description) + parser.add_argument( + "--input-bucket-dir", + type=str, + help="The directory containing bucket information files", + ) + parser.add_argument( + "--text-ddf-blocksize", + type=int, + default=256, + help="The block size for chunking jsonl files for text ddf in mb", + ) + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results in", + ) + parser.add_argument( + "--shuffle-type", + type=str, + default="tasks", + help="Type of shuffle to use before writing to parquet", + ) + return parser def jaccard_get_output_map_workflow( @@ -133,7 +136,7 @@ def jaccard_get_output_map_workflow( num_files, shuffle_type, ): - """ + """ Workflow for jaccard shuffle Args: client: dask client @@ -145,50 +148,50 @@ def jaccard_get_output_map_workflow( parts_per_worker: number of parts per worker shuffle_type: type of shuffle to use before writing to parquet """ - num_workers = get_num_workers(client) - ddf_anchor_docs_with_bk = get_anchor_and_output_map_info( - input_data_paths, - input_bucket_path, - text_ddf_blocksize, - num_files, - num_workers, - shuffle_type, - ) - ddf_anchor_docs_with_bk.to_parquet( - output_anchor_docs_with_bk_path, - write_index=False, - ) + num_workers = get_num_workers(client) + ddf_anchor_docs_with_bk = get_anchor_and_output_map_info( + input_data_paths, + input_bucket_path, + text_ddf_blocksize, + num_files, + num_workers, + shuffle_type, + ) + ddf_anchor_docs_with_bk.to_parquet( + output_anchor_docs_with_bk_path, + write_index=False, + ) def main(args): - input_data_paths = args.input_data_dirs - input_bucket_path = args.input_bucket_dir - OUTPUT_PATH = args.output_dir - output_anchor_docs_with_bk_path = os.path.join(OUTPUT_PATH, - "anchor_docs_with_bk.parquet") - client = get_client(args) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard map buckets script", flush=True) - print(f"Args = {args}") - st = time.time() - jaccard_get_output_map_workflow( - client, - input_data_paths, - input_bucket_path, - output_anchor_docs_with_bk_path, - args.text_ddf_blocksize, - args.num_files, - args.shuffle_type, - ) - et = time.time() - print(f"Bucket Mapping time taken = {et-st} s") + input_data_paths = args.input_data_dirs + input_bucket_path = args.input_bucket_dir + OUTPUT_PATH = args.output_dir + output_anchor_docs_with_bk_path = os.path.join( + OUTPUT_PATH, "anchor_docs_with_bk.parquet" + ) + client = get_client(args) + print(f"Num Workers = {get_num_workers(client)}", flush=True) + print("Connected to dask cluster", flush=True) + print("Running jaccard map buckets script", flush=True) + print(f"Args = {args}") + st = time.time() + jaccard_get_output_map_workflow( + client, + input_data_paths, + input_bucket_path, + output_anchor_docs_with_bk_path, + args.text_ddf_blocksize, + args.num_files, + 
args.shuffle_type, + ) + et = time.time() + print(f"Bucket Mapping time taken = {et-st} s") def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) - \ No newline at end of file + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/jaccard_shuffle.py b/nemo_curator/gpu_deduplication/jaccard_shuffle.py index 3f1d7f5e..846d30c4 100644 --- a/nemo_curator/gpu_deduplication/jaccard_shuffle.py +++ b/nemo_curator/gpu_deduplication/jaccard_shuffle.py @@ -19,9 +19,11 @@ from tqdm import tqdm from nemo_curator.gpu_deduplication.jaccard_utils.batch_shuffle_utils import ( - text_bytes_aware_shuffle,) + text_bytes_aware_shuffle, +) from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import ( - combine_back_adlr_ids,) + combine_back_adlr_ids, +) from nemo_curator.gpu_deduplication.jaccard_utils.io_utils import ( aggregated_anchor_docs_with_bk_read, get_restart_offsets, @@ -42,17 +44,17 @@ def write_partitioned_file(df, output_path, partition_on, batch_id): - if len(df) == 0: + if len(df) == 0: + return cudf.Series([True]) + + cudf.io.parquet.write_to_dataset( + df, + output_path, + partition_cols=[partition_on], + filename=f"batch_{batch_id}.parquet", + ) return cudf.Series([True]) - cudf.io.parquet.write_to_dataset( - df, - output_path, - partition_cols=[partition_on], - filename=f"batch_{batch_id}.parquet", - ) - return cudf.Series([True]) - def batched_merge_and_write( left_df, @@ -66,171 +68,172 @@ def batched_merge_and_write( num_workers=None, ): - total_text_partitions = left_df.npartitions - total_bucket_partitions = right_df.npartitions - - # Extract global partitioning index - left_df, global_partitioning_index = extract_partitioning_index( - left_df, - merge_on, - bk_mapping, - parts_per_bucket_batch, - total_bucket_partitions, - ) - - # Set start offsets - bucket_part_start_offset, text_part_start_offset = get_restart_offsets( - output_path) - - # Set end offsets - # NOTE: These end offsets are always set to the end - # of the data. However, we may want to be able to set - # both the start and end offsets from the command line - # in the future. - bucket_part_end_offset = total_bucket_partitions - text_part_end_offset = total_text_partitions - - # Check that offsets are valid - assert bucket_part_start_offset % parts_per_bucket_batch == 0 - assert bucket_part_end_offset > bucket_part_start_offset - assert text_part_end_offset > text_part_start_offset - - # Initialize "retry" variables - # - # - retry_count: The number of successive batches that - # we have already performed at a reduced batch size. - # - retry_threshold: The number of successive batches - # for which we should keep the batch size low - # before attempting the default batch size again. - # Every time we return to the default batch size - # and immediately fail, retry_threshold will double. 
- parts_per_text_batch_retry = None - retry_count, retry_threshold = 0, 1 - - print( - f"Starting at bucket-map partition {bucket_part_start_offset}" - f" and text-df partition {text_part_start_offset}", - flush=True, - ) - - for bucket_part_offset in tqdm( - range(bucket_part_start_offset, bucket_part_end_offset, - parts_per_bucket_batch)): - - # Outer loop over batches of "bucket-map" partitions - end_bucket_offset = min(bucket_part_offset + parts_per_bucket_batch, - bucket_part_end_offset) - print( - f"\nStarted processing bucket-map partitions {bucket_part_offset} " - f"through {end_bucket_offset} of {bucket_part_end_offset}", - flush=True, - ) - st_bucket = time.time() - - # Select our bucket-mapping batch - subset_bucket_df = right_df.partitions[bucket_part_offset:end_bucket_offset] - subset_bucket_df = subset_bucket_df.persist() + total_text_partitions = left_df.npartitions + total_bucket_partitions = right_df.npartitions - # Filter out rows of left_df that we know cannot - # align with any rows of subset_bucket_df - left_df_use = filter_text_rows_by_bucket_batch( + # Extract global partitioning index + left_df, global_partitioning_index = extract_partitioning_index( left_df, - global_partitioning_index, - bucket_part_offset, - bucket_part_end_offset, + merge_on, + bk_mapping, + parts_per_bucket_batch, total_bucket_partitions, ) - text_part_offset = text_part_start_offset - while text_part_offset < text_part_end_offset: - - # Check if we are "retrying" with a smaller "parts_per_text_batch" - if parts_per_text_batch_retry: - parts_per_text_batch_use = parts_per_text_batch_retry - else: - st_text = time.time() - parts_per_text_batch_use = parts_per_text_batch - print(f"Using {parts_per_text_batch_use} text partitions.", flush=True) - - # Select partitions for our text batch - end_text_offset = min(text_part_offset + parts_per_text_batch_use, - text_part_end_offset) - subset_text_df = left_df_use.partitions[text_part_offset:end_text_offset] - - try: - # NOTE: If we have more text-df partitions than bucket-map - # partitions, we are more likely to see an OverflowError - output_df = text_bytes_aware_shuffle( - merge_left_to_shuffled_right( - subset_text_df, - subset_bucket_df, - merge_on, - ), - partition_on, - num_workers=num_workers, + # Set start offsets + bucket_part_start_offset, text_part_start_offset = get_restart_offsets(output_path) + + # Set end offsets + # NOTE: These end offsets are always set to the end + # of the data. However, we may want to be able to set + # both the start and end offsets from the command line + # in the future. + bucket_part_end_offset = total_bucket_partitions + text_part_end_offset = total_text_partitions + + # Check that offsets are valid + assert bucket_part_start_offset % parts_per_bucket_batch == 0 + assert bucket_part_end_offset > bucket_part_start_offset + assert text_part_end_offset > text_part_start_offset + + # Initialize "retry" variables + # + # - retry_count: The number of successive batches that + # we have already performed at a reduced batch size. + # - retry_threshold: The number of successive batches + # for which we should keep the batch size low + # before attempting the default batch size again. + # Every time we return to the default batch size + # and immediately fail, retry_threshold will double. 
+ parts_per_text_batch_retry = None + retry_count, retry_threshold = 0, 1 + + print( + f"Starting at bucket-map partition {bucket_part_start_offset}" + f" and text-df partition {text_part_start_offset}", + flush=True, + ) + + for bucket_part_offset in tqdm( + range(bucket_part_start_offset, bucket_part_end_offset, parts_per_bucket_batch) + ): + + # Outer loop over batches of "bucket-map" partitions + end_bucket_offset = min( + bucket_part_offset + parts_per_bucket_batch, bucket_part_end_offset ) - except OverflowError as err: - # We encountered an overflow error! - # Let's try again with less text data - parts_per_text_batch_retry = int(parts_per_text_batch_use / 2) - if parts_per_text_batch_retry < 1: - raise err print( - f"\nWe encountered an OverflowError and will retry " - f"the current batch with {parts_per_text_batch_retry} " - f"text partitions instead of {parts_per_text_batch_use}.", + f"\nStarted processing bucket-map partitions {bucket_part_offset} " + f"through {end_bucket_offset} of {bucket_part_end_offset}", flush=True, ) - continue - - output_df = output_df.map_partitions(combine_back_adlr_ids) - batch_label = f"{end_bucket_offset}_{end_text_offset}" - written_files = output_df.map_partitions( - write_partitioned_file, - output_path, - partition_on, - batch_label, - meta=cudf.Series([True]), - ) - written_files = written_files.compute() - update_restart_offsets(output_path, bucket_part_offset, end_text_offset) - del output_df - - print( - "Text-df partition ", - f"{end_text_offset}/{text_part_end_offset} " - f"completed in {time.time()-st_text}", - flush=True, - ) - - # Update loop control-flow variables - if parts_per_text_batch_use == parts_per_text_batch: - # We succeeded at the default batch size. - # Reset the retry count - retry_count, retry_threshold = 0, 1 - else: - # We succeeded at a lower batch size - retry_count += 1 - if retry_count >= retry_threshold: - # Go back to the default text-batch size, - # but increase the retry_threshold in - # case we fail again - parts_per_text_batch_retry = None - retry_count, retry_threshold = 0, min(retry_threshold * 2, 16) - text_part_offset += parts_per_text_batch_use - - update_restart_offsets(output_path, end_bucket_offset, end_text_offset) - print( - "Bucket partition ", - f"{end_bucket_offset}/{bucket_part_end_offset} " - f"completed in {time.time()-st_bucket}", - flush=True, - ) + st_bucket = time.time() + + # Select our bucket-mapping batch + subset_bucket_df = right_df.partitions[bucket_part_offset:end_bucket_offset] + subset_bucket_df = subset_bucket_df.persist() + + # Filter out rows of left_df that we know cannot + # align with any rows of subset_bucket_df + left_df_use = filter_text_rows_by_bucket_batch( + left_df, + global_partitioning_index, + bucket_part_offset, + bucket_part_end_offset, + total_bucket_partitions, + ) - # Need to reset text_part_start_offset to 0 after - # a single bucket-batch pass (only matters if we are - # breaking the bucket-mapping df into multiple batches) - text_part_start_offset = 0 + text_part_offset = text_part_start_offset + while text_part_offset < text_part_end_offset: + + # Check if we are "retrying" with a smaller "parts_per_text_batch" + if parts_per_text_batch_retry: + parts_per_text_batch_use = parts_per_text_batch_retry + else: + st_text = time.time() + parts_per_text_batch_use = parts_per_text_batch + print(f"Using {parts_per_text_batch_use} text partitions.", flush=True) + + # Select partitions for our text batch + end_text_offset = min( + text_part_offset + 
parts_per_text_batch_use, text_part_end_offset + ) + subset_text_df = left_df_use.partitions[text_part_offset:end_text_offset] + + try: + # NOTE: If we have more text-df partitions than bucket-map + # partitions, we are more likely to see an OverflowError + output_df = text_bytes_aware_shuffle( + merge_left_to_shuffled_right( + subset_text_df, + subset_bucket_df, + merge_on, + ), + partition_on, + num_workers=num_workers, + ) + except OverflowError as err: + # We encountered an overflow error! + # Let's try again with less text data + parts_per_text_batch_retry = int(parts_per_text_batch_use / 2) + if parts_per_text_batch_retry < 1: + raise err + print( + f"\nWe encountered an OverflowError and will retry " + f"the current batch with {parts_per_text_batch_retry} " + f"text partitions instead of {parts_per_text_batch_use}.", + flush=True, + ) + continue + + output_df = output_df.map_partitions(combine_back_adlr_ids) + batch_label = f"{end_bucket_offset}_{end_text_offset}" + written_files = output_df.map_partitions( + write_partitioned_file, + output_path, + partition_on, + batch_label, + meta=cudf.Series([True]), + ) + written_files = written_files.compute() + update_restart_offsets(output_path, bucket_part_offset, end_text_offset) + del output_df + + print( + "Text-df partition ", + f"{end_text_offset}/{text_part_end_offset} " + f"completed in {time.time()-st_text}", + flush=True, + ) + + # Update loop control-flow variables + if parts_per_text_batch_use == parts_per_text_batch: + # We succeeded at the default batch size. + # Reset the retry count + retry_count, retry_threshold = 0, 1 + else: + # We succeeded at a lower batch size + retry_count += 1 + if retry_count >= retry_threshold: + # Go back to the default text-batch size, + # but increase the retry_threshold in + # case we fail again + parts_per_text_batch_retry = None + retry_count, retry_threshold = 0, min(retry_threshold * 2, 16) + text_part_offset += parts_per_text_batch_use + + update_restart_offsets(output_path, end_bucket_offset, end_text_offset) + print( + "Bucket partition ", + f"{end_bucket_offset}/{bucket_part_end_offset} " + f"completed in {time.time()-st_bucket}", + flush=True, + ) + + # Need to reset text_part_start_offset to 0 after + # a single bucket-batch pass (only matters if we are + # breaking the bucket-mapping df into multiple batches) + text_part_start_offset = 0 def jaccard_shuffling_workflow( @@ -245,7 +248,7 @@ def jaccard_shuffling_workflow( profile_path, bucket_parts_per_worker, ): - """' + """' Args: client: dask client input_data_paths: paths to input data @@ -259,137 +262,138 @@ def jaccard_shuffling_workflow( profile_path: dask profile path bucket_parts_per_worker: bucket parts per worker to process in a batch """ - # Part1. 
Reading+Shuffling Data - # Read Text from Data from jsonl files - - text_ddf = get_text_ddf_from_json_path_with_blocksize( - input_data_paths=input_data_paths, - num_files=num_files, - blocksize=text_ddf_blocksize, - ) - print( - "Graph creation for get_text_ddf_from_json_path_with_blocksize" - " complete.", - flush=True) - print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True) - st = time.time() - ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read( - input_anchor_docs_with_bk_dir, - blocksize=bucket_mapping_ddf_blocksize, - ) - print("Getting ddf_anchor_docs_with_bk completed") - print( - f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}", - flush=True, - ) - st = time.time() - num_workers = get_num_workers(client) - parts_per_batch = num_workers * parts_per_worker - print(f"parts_per_batch = {parts_per_batch}") - parts_per_bucket_batch = num_workers * bucket_parts_per_worker - print(f"parts_per_bucket_batch = {parts_per_bucket_batch}") - dask_profile_name = f"blocksize-{text_ddf_blocksize}" - dask_profile_name = dask_profile_name + f"parts_per_batch-{parts_per_batch}" - dask_profile_name = (dask_profile_name + - f"-parts_per_bucket_batch-{parts_per_bucket_batch}") - dask_profile_name = dask_profile_name + f"-jaccard-n_input_files-{num_files}.html" - - text_ddf = text_ddf[["dataset_id", "doc_id", "text"]] - - with performance_report_if(profile_path, dask_profile_name): - # Merge and write the dataframes - batched_merge_and_write( - text_ddf, - ddf_anchor_docs_with_bk, - output_path=output_shuffled_docs_path, - merge_on=["dataset_id", "doc_id"], - partition_on="output_partition_id", - parts_per_text_batch=parts_per_batch, - parts_per_bucket_batch=parts_per_bucket_batch, - bk_mapping=bk_mapping, - num_workers=num_workers, + # Part1. 
Reading+Shuffling Data + # Read Text from Data from jsonl files + + text_ddf = get_text_ddf_from_json_path_with_blocksize( + input_data_paths=input_data_paths, + num_files=num_files, + blocksize=text_ddf_blocksize, + ) + print( + "Graph creation for get_text_ddf_from_json_path_with_blocksize" " complete.", + flush=True, + ) + print(f"text_ddf.npartitions = {text_ddf.npartitions}", flush=True) + st = time.time() + ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read( + input_anchor_docs_with_bk_dir, + blocksize=bucket_mapping_ddf_blocksize, ) - print(f"Writing+Shuffling data took = {time.time()-st} s", flush=True) + print("Getting ddf_anchor_docs_with_bk completed") + print( + f"ddf_anchor_docs_with_bk.npartitions = {ddf_anchor_docs_with_bk.npartitions}", + flush=True, + ) + st = time.time() + num_workers = get_num_workers(client) + parts_per_batch = num_workers * parts_per_worker + print(f"parts_per_batch = {parts_per_batch}") + parts_per_bucket_batch = num_workers * bucket_parts_per_worker + print(f"parts_per_bucket_batch = {parts_per_bucket_batch}") + dask_profile_name = f"blocksize-{text_ddf_blocksize}" + dask_profile_name = dask_profile_name + f"parts_per_batch-{parts_per_batch}" + dask_profile_name = ( + dask_profile_name + f"-parts_per_bucket_batch-{parts_per_bucket_batch}" + ) + dask_profile_name = dask_profile_name + f"-jaccard-n_input_files-{num_files}.html" + + text_ddf = text_ddf[["dataset_id", "doc_id", "text"]] + + with performance_report_if(profile_path, dask_profile_name): + # Merge and write the dataframes + batched_merge_and_write( + text_ddf, + ddf_anchor_docs_with_bk, + output_path=output_shuffled_docs_path, + merge_on=["dataset_id", "doc_id"], + partition_on="output_partition_id", + parts_per_text_batch=parts_per_batch, + parts_per_bucket_batch=parts_per_bucket_batch, + bk_mapping=bk_mapping, + num_workers=num_workers, + ) + print(f"Writing+Shuffling data took = {time.time()-st} s", flush=True) def main(args): - input_data_paths = args.input_data_dirs - input_anchor_docs_with_bk_dir = args.input_bucket_mapping_dir - OUTPUT_PATH = args.output_dir - output_anchor_docs_with_bk_path = os.path.join(OUTPUT_PATH, - "anchor_docs_with_bk.parquet") - output_shuffled_docs_path = os.path.join(OUTPUT_PATH, "shuffled_docs.parquet") - client = get_client(args) - print(f"Num Workers = {get_num_workers(client)}", flush=True) - print("Connected to dask cluster", flush=True) - print("Running jaccard shuffle script", flush=True) - print(f"Args = {args}") - st = time.time() - jaccard_shuffling_workflow( - client=client, - input_data_paths=input_data_paths, - input_anchor_docs_with_bk_dir=input_anchor_docs_with_bk_dir, - output_shuffled_docs_path=output_shuffled_docs_path, - text_ddf_blocksize=args.text_ddf_blocksize, - bucket_mapping_ddf_blocksize=args.bucket_mapping_ddf_blocksize, - num_files=args.num_files, - parts_per_worker=args.parts_per_worker, - profile_path=args.profile_path, - bucket_parts_per_worker=args.bucket_parts_per_worker, - ) - et = time.time() - print(f"Jaccard Shuffle E2E time taken = {et-st} s") + input_data_paths = args.input_data_dirs + input_anchor_docs_with_bk_dir = args.input_bucket_mapping_dir + OUTPUT_PATH = args.output_dir + output_anchor_docs_with_bk_path = os.path.join( + OUTPUT_PATH, "anchor_docs_with_bk.parquet" + ) + output_shuffled_docs_path = os.path.join(OUTPUT_PATH, "shuffled_docs.parquet") + client = get_client(args) + print(f"Num Workers = {get_num_workers(client)}", flush=True) + print("Connected to dask cluster", flush=True) + 
print("Running jaccard shuffle script", flush=True) + print(f"Args = {args}") + st = time.time() + jaccard_shuffling_workflow( + client=client, + input_data_paths=input_data_paths, + input_anchor_docs_with_bk_dir=input_anchor_docs_with_bk_dir, + output_shuffled_docs_path=output_shuffled_docs_path, + text_ddf_blocksize=args.text_ddf_blocksize, + bucket_mapping_ddf_blocksize=args.bucket_mapping_ddf_blocksize, + num_files=args.num_files, + parts_per_worker=args.parts_per_worker, + profile_path=args.profile_path, + bucket_parts_per_worker=args.bucket_parts_per_worker, + ) + et = time.time() + print(f"Jaccard Shuffle E2E time taken = {et-st} s") def attach_args(parser=None): - description = """Shuffles input text documents based on the given bucket + description = """Shuffles input text documents based on the given bucket map. The output is a partitioned parquet dataset with the documents shuffled by buckets """ - if not parser: - parser = parse_nc_args(description=description) - - parser.add_argument( - "--input-bucket-mapping-dir", - type=str, - help="The directory containing anchor docs with bk files", - ) - parser.add_argument( - "--text-ddf-blocksize", - type=int, - default=256, - help="The block size for chunking jsonl files for text ddf in mb", - ) - parser.add_argument( - "--bucket-mapping-ddf-blocksize", - type=int, - default=256, - help="The block size for for anchor_docs_with_bk ddf in mb", - ) - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results in", - ) - parser.add_argument( - "--parts-per-worker", - default=2, - type=int, - help="The number of parts to process per worker per batch", - ) - parser.add_argument( - "--bucket-parts-per-worker", - default=8, - type=int, - help="The number of bucket parts to process per worker per batch", - ) - return parser + if not parser: + parser = parse_nc_args(description=description) + + parser.add_argument( + "--input-bucket-mapping-dir", + type=str, + help="The directory containing anchor docs with bk files", + ) + parser.add_argument( + "--text-ddf-blocksize", + type=int, + default=256, + help="The block size for chunking jsonl files for text ddf in mb", + ) + parser.add_argument( + "--bucket-mapping-ddf-blocksize", + type=int, + default=256, + help="The block size for for anchor_docs_with_bk ddf in mb", + ) + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results in", + ) + parser.add_argument( + "--parts-per-worker", + default=2, + type=int, + help="The number of parts to process per worker per batch", + ) + parser.add_argument( + "--bucket-parts-per-worker", + default=8, + type=int, + help="The number of bucket parts to process per worker per batch", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) - + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/__init__.py b/nemo_curator/gpu_deduplication/jaccard_utils/__init__.py index fe99e99a..d9155f92 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/__init__.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. 
\ No newline at end of file +# limitations under the License. diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/batch_shuffle_utils.py b/nemo_curator/gpu_deduplication/jaccard_utils/batch_shuffle_utils.py index 5bc6c1f7..755112d0 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/batch_shuffle_utils.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/batch_shuffle_utils.py @@ -17,9 +17,7 @@ import numpy as np from dask import config from dask.dataframe.shuffle import rearrange_by_column -from dask_cuda.explicit_comms.dataframe.shuffle import ( - shuffle as explicit_comms_shuffle, -) +from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle from packaging.version import Version from nemo_curator.gpu_deduplication.jaccard_utils.get_output_map_utils import ( diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/doc_id_mapping.py b/nemo_curator/gpu_deduplication/jaccard_utils/doc_id_mapping.py index 27c25f63..e29c626f 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/doc_id_mapping.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/doc_id_mapping.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + def convert_str_id_to_int(df, id_column="id"): """ Converts the legacy id format "dataset_name-0000034" diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/get_anchor_utils.py b/nemo_curator/gpu_deduplication/jaccard_utils/get_anchor_utils.py index de8fef50..ea734ded 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/get_anchor_utils.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/get_anchor_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + def random_select_anchor(df_bk, n=2): """ Randomly select `n` anchors from each bucket. diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/get_output_map_utils.py b/nemo_curator/gpu_deduplication/jaccard_utils/get_output_map_utils.py index 2c492943..bdbdedc6 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/get_output_map_utils.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/get_output_map_utils.py @@ -16,6 +16,7 @@ import dask_cudf import numba import numpy as np + from nemo_curator._compat import DASK_SHUFFLE_METHOD_ARG diff --git a/nemo_curator/gpu_deduplication/jaccard_utils/io_utils.py b/nemo_curator/gpu_deduplication/jaccard_utils/io_utils.py index 0c1d4895..a24b99dd 100644 --- a/nemo_curator/gpu_deduplication/jaccard_utils/io_utils.py +++ b/nemo_curator/gpu_deduplication/jaccard_utils/io_utils.py @@ -20,10 +20,7 @@ import numpy as np from dask import dataframe as dd -from nemo_curator.gpu_deduplication.ioutils import ( - bucketed_read, - read_json_func, -) +from nemo_curator.gpu_deduplication.ioutils import bucketed_read, read_json_func from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import ( convert_adlr_id_to_int, ) diff --git a/nemo_curator/gpu_deduplication/prepare_fuzzy_ids.py b/nemo_curator/gpu_deduplication/prepare_fuzzy_ids.py index 23278940..b06601b8 100644 --- a/nemo_curator/gpu_deduplication/prepare_fuzzy_ids.py +++ b/nemo_curator/gpu_deduplication/prepare_fuzzy_ids.py @@ -13,33 +13,34 @@ # limitations under the License. 
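# --- Editor's illustrative aside (not part of the patch): the doc_id_mapping
# hunk above touches convert_str_id_to_int, whose docstring says it converts
# the legacy id format "dataset_name-0000034" into numeric columns. A rough,
# CPU-only pandas sketch of that idea follows; the real helper operates on
# cudf DataFrames, and the output column names used here are assumptions.
import pandas as pd

def split_legacy_id(df: pd.DataFrame, id_column: str = "id") -> pd.DataFrame:
    # "wiki-0000034" -> dataset part "wiki", document part 34
    parts = df[id_column].str.rsplit("-", n=1, expand=True)
    df["dataset_id"] = parts[0].apply(hash)  # stand-in for a stable hash
    df["doc_id"] = parts[1].astype("int64")
    return df.drop(columns=id_column)

# Example: split_legacy_id(pd.DataFrame({"id": ["wiki-0000034", "cc-0000001"]}))
# --- End of aside.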
import argparse -import cudf import json -from dask.distributed import Client + +import cudf from dask import dataframe as dd +from dask.distributed import Client def main(args): - # Create the ID mapping - df = cudf.DataFrame() - df['base_id'] = [base_id for base_id in args.base_ids.split(",")] - df['dataset_id'] = df['base_id'].hash_values() - df_pd = df.to_pandas() + # Create the ID mapping + df = cudf.DataFrame() + df["base_id"] = [base_id for base_id in args.base_ids.split(",")] + df["dataset_id"] = df["base_id"].hash_values() + df_pd = df.to_pandas() - output_dict = { - hashed_id: base_id - for base_id, hashed_id in zip(df_pd['base_id'], df_pd['dataset_id']) - } + output_dict = { + hashed_id: base_id + for base_id, hashed_id in zip(df_pd["base_id"], df_pd["dataset_id"]) + } - # Write out the mapping to disk - with open(args.output_id_mapping, 'w') as output_file: - json.dump(output_dict, output_file) + # Write out the mapping to disk + with open(args.output_id_mapping, "w") as output_file: + json.dump(output_dict, output_file) - # Index the parquet files by group - client = Client() - ddf = dd.read_parquet(args.path_to_connected_components) - ddf = ddf.set_index("group") - ddf.to_parquet(args.output_indexed_connected_components) + # Index the parquet files by group + client = Client() + ddf = dd.read_parquet(args.path_to_connected_components) + ddf = ddf.set_index("group") + ddf.to_parquet(args.output_indexed_connected_components) def attach_args( @@ -49,45 +50,46 @@ def attach_args( extraction to .txt and .jsonl files """, formatter_class=argparse.ArgumentDefaultsHelpFormatter, - )): - parser.add_argument( - "--base-ids", - type=str, - default="doc_id", - help="A comma-delimited list of base-ids that were used for " - "different datasets during dedup. For example, " - "if you were deduplicating Wikipedia and Common Crawl, you might " - "have adlr_ids such has wiki-000001 and cc-000001. " - "The base-ids in this case would be 'wiki,cc'", - ) - parser.add_argument( - "--path-to-connected-components", - type=str, - default=None, - help="Path to the connected components that is created " - "at the last step of the fuzzy dedup.", - ) - parser.add_argument( - "--output-indexed-connected-components", - type=str, - default=None, - help="Path to the output connected components " - "that have been prepared for " - "extraction to .txt and .jsonl files", - ) - parser.add_argument( - "--output-id-mapping", - type=str, - default="mapping.json", - help="A mapping between each of the strings specified " - "in '--base-ids' and their respective hashes", - ) - return parser + ) +): + parser.add_argument( + "--base-ids", + type=str, + default="doc_id", + help="A comma-delimited list of base-ids that were used for " + "different datasets during dedup. For example, " + "if you were deduplicating Wikipedia and Common Crawl, you might " + "have adlr_ids such has wiki-000001 and cc-000001. 
" + "The base-ids in this case would be 'wiki,cc'", + ) + parser.add_argument( + "--path-to-connected-components", + type=str, + default=None, + help="Path to the connected components that is created " + "at the last step of the fuzzy dedup.", + ) + parser.add_argument( + "--output-indexed-connected-components", + type=str, + default=None, + help="Path to the output connected components " + "that have been prepared for " + "extraction to .txt and .jsonl files", + ) + parser.add_argument( + "--output-id-mapping", + type=str, + default="mapping.json", + help="A mapping between each of the strings specified " + "in '--base-ids' and their respective hashes", + ) + return parser if __name__ == "__main__": - main(attach_args().parse_args()) + main(attach_args().parse_args()) def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/utils.py b/nemo_curator/gpu_deduplication/utils.py index 90e7c9e0..ed69477b 100644 --- a/nemo_curator/gpu_deduplication/utils.py +++ b/nemo_curator/gpu_deduplication/utils.py @@ -25,201 +25,207 @@ def create_logger(rank, log_file, name="logger", log_level=logging.INFO): - # Create the logger - logger = logging.getLogger(name) - logger.setLevel(log_level) + # Create the logger + logger = logging.getLogger(name) + logger.setLevel(log_level) - myhost = socket.gethostname() + myhost = socket.gethostname() - extra = {"host": myhost, "rank": rank} - formatter = logging.Formatter( - "%(asctime)s | %(host)s | Rank %(rank)s | %(message)s") + extra = {"host": myhost, "rank": rank} + formatter = logging.Formatter( + "%(asctime)s | %(host)s | Rank %(rank)s | %(message)s" + ) - # File handler for output - file_handler = logging.FileHandler(log_file, mode="a") - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger = logging.LoggerAdapter(logger, extra) + # File handler for output + file_handler = logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger = logging.LoggerAdapter(logger, extra) - return logger + return logger -#TODO: Remove below to use nemo_curator.distributed_utils.get_client +# TODO: Remove below to use nemo_curator.distributed_utils.get_client def get_client(args) -> Client: - if args.scheduler_address: - if args.scheduler_file: - raise ValueError( - "Only one of scheduler_address or scheduler_file can be provided") + if args.scheduler_address: + if args.scheduler_file: + raise ValueError( + "Only one of scheduler_address or scheduler_file can be provided" + ) + else: + return Client(address=args.scheduler_address, timeout="30s") + elif args.scheduler_file: + return Client(scheduler_file=args.scheduler_file, timeout="30s") else: - return Client(address=args.scheduler_address, timeout="30s") - elif args.scheduler_file: - return Client(scheduler_file=args.scheduler_file, timeout="30s") - else: - extra_kwargs = ({ - "enable_tcp_over_ucx": True, - "enable_nvlink": True, - "enable_infiniband": False, - "enable_rdmacm": False, - } if args.nvlink_only and args.protocol == "ucx" else {}) - - cluster = LocalCUDACluster( - rmm_pool_size=args.rmm_pool_size, - protocol=args.protocol, - rmm_async=True, - **extra_kwargs, - ) - return Client(cluster) + extra_kwargs = ( + { + "enable_tcp_over_ucx": True, + "enable_nvlink": True, + "enable_infiniband": False, + "enable_rdmacm": False, + } + if args.nvlink_only and args.protocol == "ucx" + else {} + ) + + cluster = LocalCUDACluster( + 
rmm_pool_size=args.rmm_pool_size, + protocol=args.protocol, + rmm_async=True, + **extra_kwargs, + ) + return Client(cluster) def performance_report_if(path=None, report_name="dask-profile.html"): - if path is not None: - return performance_report(os.path.join(path, report_name)) - else: - return nullcontext() + if path is not None: + return performance_report(os.path.join(path, report_name)) + else: + return nullcontext() -#TODO: Remove below to use nemo_curator.distributed_utils._enable_spilling +# TODO: Remove below to use nemo_curator.distributed_utils._enable_spilling def enable_spilling(): - """ + """ Enables spilling to host memory for cudf """ - cudf.set_option("spill", True) + cudf.set_option("spill", True) def get_num_workers(client): - """ + """ Returns the number of workers in the cluster """ - worker_list = list(client.scheduler_info()["workers"].keys()) - return len(worker_list) + worker_list = list(client.scheduler_info()["workers"].keys()) + return len(worker_list) def get_list_of_lists(lst, nchunks): - """ + """ Splits a list into nchunks lists """ - return [lst[i::nchunks] for i in range(nchunks)] + return [lst[i::nchunks] for i in range(nchunks)] def parse_nc_args( description="Default gpu dedup nemo_curator argument parser", ) -> argparse.ArgumentParser: - """ + """ Adds default set of arguments that are common to multiple stages of the pipeline """ - parser = argparse.ArgumentParser( - description, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--input-data-dirs", - type=str, - nargs="+", - default=None, - required=False, - help="Input directories consisting of .jsonl files that are accessible " - "to all nodes. This path must be accessible by all machines in the cluster", - ) - parser.add_argument( - "--scheduler-address", - type=str, - default=None, - help="Address to the scheduler of a created dask cluster. If not provided" - "a single node LocalCUDACluster will be started.", - ) - parser.add_argument( - "--scheduler-file", - type=str, - default=None, - help="Path to the scheduler file of a created dask cluster. If not provided" - " a single node LocalCUDACluster will be started.", - ) - parser.add_argument( - "--rmm-pool-size", - type=str, - default=None, - help="Initial pool size to use for the RMM Pool Memory allocator" - "Note: This only applies to the localCUDACluster. If providing an user created " - "cluster refer to" - "https://docs.rapids.ai/api/dask-cuda/stable/api.html#cmdoption-dask-cuda-rmm-pool-size", # noqa: E501 - ) - parser.add_argument( - "--protocol", - type=str, - default="tcp", - help="Protcol to use for dask cluster" - "Note: This only applies to the localCUDACluster. If providing an user created " - "cluster refer to" - "https://docs.rapids.ai/api/dask-cuda/stable/api.html#cmdoption-dask-cuda-protocol", # noqa: E501 - ) - parser.add_argument( - "--nvlink-only", - action="store_true", - help="Start a local cluster with only NVLink enabled." - "Only applicable when protocol=ucx and no scheduler file/address is specified", - ) - parser.add_argument( - "--input-json-text-field", - type=str, - default="text", - help="The name of the field within each json object of the jsonl " - "file that contains the text from which minhashes will be computed. ", - ) - parser.add_argument( - "--input-json-id-field", - type=str, - default="adlr_id", - help="The name of the field within each json object of the jsonl " - "file that assigns a unqiue ID to each document. 
" - "Can be created by running the script " - "'./prospector/add_id.py' which adds the field 'adlr_id' " - "to the documents in a distributed fashion", - ) - parser.add_argument( - "--log-dir", - type=str, - default="./logs/", - help="The output log directory where node and local", - ) - parser.add_argument( - "--files-per-partition", - type=int, - default=2, - help="Number of jsonl files to combine into single partition", - ) - parser.add_argument( - "--num-files", - type=int, - default=None, - help="Upper limit on the number of json files to process", - ) - parser.add_argument( - "--log-frequency", - type=int, - default=500, - help="The frequency with which to write log messages when " - "computing MinHashses. By default a log message will " - "be written every 500 partitions", - ) - parser.add_argument( - "--profile-path", - type=str, - default=None, - help="Path to save dask profile", - ) - return parser + parser = argparse.ArgumentParser( + description, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--input-data-dirs", + type=str, + nargs="+", + default=None, + required=False, + help="Input directories consisting of .jsonl files that are accessible " + "to all nodes. This path must be accessible by all machines in the cluster", + ) + parser.add_argument( + "--scheduler-address", + type=str, + default=None, + help="Address to the scheduler of a created dask cluster. If not provided" + "a single node LocalCUDACluster will be started.", + ) + parser.add_argument( + "--scheduler-file", + type=str, + default=None, + help="Path to the scheduler file of a created dask cluster. If not provided" + " a single node LocalCUDACluster will be started.", + ) + parser.add_argument( + "--rmm-pool-size", + type=str, + default=None, + help="Initial pool size to use for the RMM Pool Memory allocator" + "Note: This only applies to the localCUDACluster. If providing an user created " + "cluster refer to" + "https://docs.rapids.ai/api/dask-cuda/stable/api.html#cmdoption-dask-cuda-rmm-pool-size", # noqa: E501 + ) + parser.add_argument( + "--protocol", + type=str, + default="tcp", + help="Protcol to use for dask cluster" + "Note: This only applies to the localCUDACluster. If providing an user created " + "cluster refer to" + "https://docs.rapids.ai/api/dask-cuda/stable/api.html#cmdoption-dask-cuda-protocol", # noqa: E501 + ) + parser.add_argument( + "--nvlink-only", + action="store_true", + help="Start a local cluster with only NVLink enabled." + "Only applicable when protocol=ucx and no scheduler file/address is specified", + ) + parser.add_argument( + "--input-json-text-field", + type=str, + default="text", + help="The name of the field within each json object of the jsonl " + "file that contains the text from which minhashes will be computed. ", + ) + parser.add_argument( + "--input-json-id-field", + type=str, + default="adlr_id", + help="The name of the field within each json object of the jsonl " + "file that assigns a unqiue ID to each document. 
" + "Can be created by running the script " + "'./prospector/add_id.py' which adds the field 'adlr_id' " + "to the documents in a distributed fashion", + ) + parser.add_argument( + "--log-dir", + type=str, + default="./logs/", + help="The output log directory where node and local", + ) + parser.add_argument( + "--files-per-partition", + type=int, + default=2, + help="Number of jsonl files to combine into single partition", + ) + parser.add_argument( + "--num-files", + type=int, + default=None, + help="Upper limit on the number of json files to process", + ) + parser.add_argument( + "--log-frequency", + type=int, + default=500, + help="The frequency with which to write log messages when " + "computing MinHashses. By default a log message will " + "be written every 500 partitions", + ) + parser.add_argument( + "--profile-path", + type=str, + default=None, + help="Path to save dask profile", + ) + return parser def timer(func): - def wrapper(*args, **kw): - print(f"function {func.__name__} started...") - start = time() - res = func(*args, **kw) - duration = time() - start - timing = f"function {func.__name__} finished in {duration:.1f} seconds" - print(timing) - return res + def wrapper(*args, **kw): + print(f"function {func.__name__} started...") + start = time() + res = func(*args, **kw) + duration = time() - start + timing = f"function {func.__name__} finished in {duration:.1f} seconds" + print(timing) + return res - return wrapper + return wrapper diff --git a/nemo_curator/gpu_deduplication/verify_all_pairs_jaccard.py b/nemo_curator/gpu_deduplication/verify_all_pairs_jaccard.py index 36aea14d..ae7e6c65 100644 --- a/nemo_curator/gpu_deduplication/verify_all_pairs_jaccard.py +++ b/nemo_curator/gpu_deduplication/verify_all_pairs_jaccard.py @@ -26,141 +26,147 @@ def num_ngram(ds): - return ds.str.character_ngrams(5, True).list.unique().list.len() + return ds.str.character_ngrams(5, True).list.unique().list.len() def write_eligible_pairs(dedup_with_text_path, cache_dir): - df = cudf.read_parquet(dedup_with_text_path) - df["num_ngram"] = num_ngram(df["text"]) - df.drop(columns="text", inplace=True) - df["group"] = 0 - B = 8_000 - rm = 0 - for s in range(0, df.shape[0], B): - e = min(s + B, df.shape[0]) - da = df.iloc[s:e] - db = da.merge(df, on="group") - mask = db["adlr_id_x"] < db["adlr_id_y"] - db = db[mask] - mask = (db["num_ngram_x"] < db["num_ngram_y"] * - 0.8) | (db["num_ngram_y"] < db["num_ngram_x"] * 0.8) - print(db.shape, mask.sum()) - rm += mask.sum() - db = db[~mask] - db.drop(columns=["group", "num_ngram_x", "num_ngram_y"], inplace=True) - db.to_parquet(f"{cache_dir}/pair_{s}.parquet") - del da, db - print("total pairs removed", rm) + df = cudf.read_parquet(dedup_with_text_path) + df["num_ngram"] = num_ngram(df["text"]) + df.drop(columns="text", inplace=True) + df["group"] = 0 + B = 8_000 + rm = 0 + for s in range(0, df.shape[0], B): + e = min(s + B, df.shape[0]) + da = df.iloc[s:e] + db = da.merge(df, on="group") + mask = db["adlr_id_x"] < db["adlr_id_y"] + db = db[mask] + mask = (db["num_ngram_x"] < db["num_ngram_y"] * 0.8) | ( + db["num_ngram_y"] < db["num_ngram_x"] * 0.8 + ) + print(db.shape, mask.sum()) + rm += mask.sum() + db = db[~mask] + db.drop(columns=["group", "num_ngram_x", "num_ngram_y"], inplace=True) + db.to_parquet(f"{cache_dir}/pair_{s}.parquet") + del da, db + print("total pairs removed", rm) def merge_text(df, dedup_with_text_path): - dg = cudf.read_parquet(dedup_with_text_path) - for i in "xy": - df = df.merge(dg, left_on=f"adlr_id_{i}", right_on="adlr_id") - 
df.drop(columns="adlr_id", inplace=True) - return df + dg = cudf.read_parquet(dedup_with_text_path) + for i in "xy": + df = df.merge(dg, left_on=f"adlr_id_{i}", right_on="adlr_id") + df.drop(columns="adlr_id", inplace=True) + return df def get_max_num_rows_to_process_once(df): - nbytes = max(df["text_x"].str.byte_count().sum(), - df["text_y"].str.byte_count().sum()) - - # TODO: fix below - # to 4x - exploded_bytes = nbytes * 5 * 4 - max_chars_allowed = 2_147_483_647 - byte_ratio = int(exploded_bytes) // max_chars_allowed - if byte_ratio > 1: - nrows_at_once = len(df) // byte_ratio - else: - nrows_at_once = len(df) + nbytes = max( + df["text_x"].str.byte_count().sum(), df["text_y"].str.byte_count().sum() + ) + + # TODO: fix below + # to 4x + exploded_bytes = nbytes * 5 * 4 + max_chars_allowed = 2_147_483_647 + byte_ratio = int(exploded_bytes) // max_chars_allowed + if byte_ratio > 1: + nrows_at_once = len(df) // byte_ratio + else: + nrows_at_once = len(df) - nrows_at_once = max(1, nrows_at_once) - return nrows_at_once + nrows_at_once = max(1, nrows_at_once) + return nrows_at_once def compute_jaccard_pair(docs_df): - nrows_at_once = get_max_num_rows_to_process_once(docs_df) - result_ls = [] - for i in range(0, docs_df.shape[0], nrows_at_once): - pair_df = docs_df[i:i + nrows_at_once] - if len(pair_df) == 0: - result_df = create_empty_jaccard_result() - else: - result_df = compute_jaccard_partition(pair_df) - result_ls.append(result_df) - if len(result_ls) == 0: - return create_empty_jaccard_result() - df_pair = cudf.concat(result_ls) - return df_pair + nrows_at_once = get_max_num_rows_to_process_once(docs_df) + result_ls = [] + for i in range(0, docs_df.shape[0], nrows_at_once): + pair_df = docs_df[i : i + nrows_at_once] + if len(pair_df) == 0: + result_df = create_empty_jaccard_result() + else: + result_df = compute_jaccard_partition(pair_df) + result_ls.append(result_df) + if len(result_ls) == 0: + return create_empty_jaccard_result() + df_pair = cudf.concat(result_ls) + return df_pair def run_verify_all_pairs_jaccard(dedup_with_text_path, cache_dir, output_dir): - ddf = dask_cudf.read_parquet(f"{cache_dir}/pair_*.parquet") - ddf = ddf.repartition(npartitions=2048) - - meta_df = cudf.DataFrame({ - "adlr_id_x": [0], - "adlr_id_y": [0], - "text_x": ["x"], - "text_y": ["x"], - }) - - ddf = ddf.map_partitions(partial(merge_text, - dedup_with_text_path=dedup_with_text_path), - meta=meta_df) - - meta_df = cudf.DataFrame({ - "adlr_id_x": [0], - "adlr_id_y": [0], - "jaccard": [1.0], - }) - - ddf = ddf.map_partitions(compute_jaccard_pair, meta=meta_df) - mask = ddf["jaccard"] > 0.8 - dup_pairs = ddf[mask].compute() - print("# of duplicated pairs with jaccard>0.8", dup_pairs.shape[0]) - dup_pairs.to_parquet(f"{output_dir}/duplicated_pairs.parquet") + ddf = dask_cudf.read_parquet(f"{cache_dir}/pair_*.parquet") + ddf = ddf.repartition(npartitions=2048) + + meta_df = cudf.DataFrame( + { + "adlr_id_x": [0], + "adlr_id_y": [0], + "text_x": ["x"], + "text_y": ["x"], + } + ) + + ddf = ddf.map_partitions( + partial(merge_text, dedup_with_text_path=dedup_with_text_path), meta=meta_df + ) + + meta_df = cudf.DataFrame( + { + "adlr_id_x": [0], + "adlr_id_y": [0], + "jaccard": [1.0], + } + ) + + ddf = ddf.map_partitions(compute_jaccard_pair, meta=meta_df) + mask = ddf["jaccard"] > 0.8 + dup_pairs = ddf[mask].compute() + print("# of duplicated pairs with jaccard>0.8", dup_pairs.shape[0]) + dup_pairs.to_parquet(f"{output_dir}/duplicated_pairs.parquet") def main(args): - start = time() - description = """Verify 
correctness of deduped results by calculating all pairs""" - dedup_with_text_path = f"{args.output_dir}/dedup_with_text.parquet" + start = time() + description = """Verify correctness of deduped results by calculating all pairs""" + dedup_with_text_path = f"{args.output_dir}/dedup_with_text.parquet" - write_eligible_pairs(dedup_with_text_path, args.cache_dir) - client = get_client(args) + write_eligible_pairs(dedup_with_text_path, args.cache_dir) + client = get_client(args) - # Run actual computation - run_verify_all_pairs_jaccard( - dedup_with_text_path, - args.cache_dir, - args.output_dir, - ) - print(f"All done in {time()-start:.1f} seconds") + # Run actual computation + run_verify_all_pairs_jaccard( + dedup_with_text_path, + args.cache_dir, + args.output_dir, + ) + print(f"All done in {time()-start:.1f} seconds") def attach_args(parser=None): - description = """verify all pairs jaccard""" - if not parser: - parser = parse_nc_args(description=description) - - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results to", - ) - parser.add_argument( - "--cache-dir", - type=str, - help="The cache directory to write intermediate results to", - ) - return parser + description = """verify all pairs jaccard""" + if not parser: + parser = parse_nc_args(description=description) + + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results to", + ) + parser.add_argument( + "--cache-dir", + type=str, + help="The cache directory to write intermediate results to", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/gpu_deduplication/write_deduped_result_with_text.py b/nemo_curator/gpu_deduplication/write_deduped_result_with_text.py index 851f3bd2..155c56bc 100644 --- a/nemo_curator/gpu_deduplication/write_deduped_result_with_text.py +++ b/nemo_curator/gpu_deduplication/write_deduped_result_with_text.py @@ -17,62 +17,67 @@ import cudf from nemo_curator.gpu_deduplication.jaccard_utils.io_utils import ( - get_text_ddf_from_json_path,) + get_text_ddf_from_json_path, +) from nemo_curator.gpu_deduplication.utils import parse_nc_args def merge_text_partition(df, connected_components_path): - res = cudf.read_parquet(connected_components_path).drop(columns="dataset_id") - res = res.drop_duplicates("group") - res = res.drop(columns=["group"]) - df = res.merge(df, on="doc_id", how="left") - df = df.rename(columns={"doc_id": "adlr_id"}) - return df.drop(columns="dataset_id") + res = cudf.read_parquet(connected_components_path).drop(columns="dataset_id") + res = res.drop_duplicates("group") + res = res.drop(columns=["group"]) + df = res.merge(df, on="doc_id", how="left") + df = df.rename(columns={"doc_id": "adlr_id"}) + return df.drop(columns="dataset_id") def write_result_text_parquet(original_path, output_dir): - ddf = get_text_ddf_from_json_path(original_path, - num_files=-1, - files_per_input_partition=10) + ddf = get_text_ddf_from_json_path( + original_path, num_files=-1, files_per_input_partition=10 + ) - connected_components_path = f"{output_dir}/connected_components.parquet" - print(ddf.head()) - merge_func = partial(merge_text_partition, - connected_components_path=connected_components_path) - ddf = ddf.map_partitions(merge_func, meta={"adlr_id": "uint32", "text": "O"}) + connected_components_path = f"{output_dir}/connected_components.parquet" + 
print(ddf.head()) + merge_func = partial( + merge_text_partition, connected_components_path=connected_components_path + ) + ddf = ddf.map_partitions(merge_func, meta={"adlr_id": "uint32", "text": "O"}) - mask = ddf.text.isnull() - ddf = ddf[~mask] + mask = ddf.text.isnull() + ddf = ddf[~mask] + + df = ddf.compute() + df = df.reset_index(drop=True) + df.to_parquet(f"{output_dir}/dedup_with_text.parquet") - df = ddf.compute() - df = df.reset_index(drop=True) - df.to_parquet(f"{output_dir}/dedup_with_text.parquet") def main(args): - write_result_text_parquet(original_path=[args.original_path], - output_dir=args.output_dir) + write_result_text_parquet( + original_path=[args.original_path], output_dir=args.output_dir + ) + def attach_args(parser=None): - description = """verify all pairs jaccard""" - if not parser: - parser = parse_nc_args(description=description) - - parser.add_argument( - "--output-dir", - type=str, - help="The output directory to write results to", - ) - parser.add_argument( - "--original-path", - type=str, - help="The path of original jsonl files", - ) - return parser + description = """verify all pairs jaccard""" + if not parser: + parser = parse_nc_args(description=description) + + parser.add_argument( + "--output-dir", + type=str, + help="The output directory to write results to", + ) + parser.add_argument( + "--original-path", + type=str, + help="The path of original jsonl files", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - args = attach_args().parse_args() + args = attach_args().parse_args() diff --git a/nemo_curator/log.py b/nemo_curator/log.py index 33f1461e..e69afc1f 100644 --- a/nemo_curator/log.py +++ b/nemo_curator/log.py @@ -12,81 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. 
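# --- Editor's illustrative aside (not part of the patch): the hunks above
# rely on Dask's map_partitions with an explicit `meta` (for example
# {"adlr_id": "uint32", "text": "O"}) followed by a null-text mask. A
# CPU-only dask.dataframe sketch of the same pattern follows; the real code
# runs on dask_cudf, and the toy column/function names are assumptions.
import dask.dataframe as dd
import pandas as pd

def add_length(part: pd.DataFrame) -> pd.DataFrame:
    part = part.copy()
    part["n_chars"] = part["text"].str.len().astype("float64")
    return part

pdf = pd.DataFrame({"adlr_id": [0, 1, 2, 3], "text": ["a", "bb", None, "dddd"]})
ddf = dd.from_pandas(pdf, npartitions=2)

# `meta` declares the output schema so Dask can build the graph lazily
# without executing add_length on real data first.
ddf = ddf.map_partitions(
    add_length,
    meta={"adlr_id": "int64", "text": "object", "n_chars": "float64"},
)

# Drop rows with missing text, mirroring the isnull() mask above.
ddf = ddf[~ddf["text"].isnull()]
print(ddf.compute())
# --- End of aside.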
-import os import logging +import os import socket from nemo_curator.utils.file_utils import expand_outdir_and_mkdir -def create_logger(rank, log_file, name='logger', log_level=logging.INFO): - # Create the logger - logger = logging.getLogger(name) - logger.setLevel(log_level) +def create_logger(rank, log_file, name="logger", log_level=logging.INFO): + # Create the logger + logger = logging.getLogger(name) + logger.setLevel(log_level) - myhost = socket.gethostname() + myhost = socket.gethostname() - extra = {'host': myhost, 'rank': rank} - formatter = logging.Formatter( - '%(asctime)s | %(host)s | Rank %(rank)s | %(message)s') + extra = {"host": myhost, "rank": rank} + formatter = logging.Formatter( + "%(asctime)s | %(host)s | Rank %(rank)s | %(message)s" + ) - # File handler for output - file_handler = logging.FileHandler(log_file, mode='a') - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # File handler for output + file_handler = logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) - logger = logging.LoggerAdapter(logger, extra) + logger = logging.LoggerAdapter(logger, extra) - return logger + return logger def create_rank_logger( rank, log_dir, - name='node_logger', + name="node_logger", log_level=logging.INFO, ): - # Make the log directory if it does not exist - log_dir = expand_outdir_and_mkdir(log_dir) + # Make the log directory if it does not exist + log_dir = expand_outdir_and_mkdir(log_dir) - # Create the rank subdirectory - rank_tag = str(rank).rjust(3, '0') - rank_dir = os.path.join(log_dir, f'rank_{rank_tag}') - rank_dir = expand_outdir_and_mkdir(rank_dir) + # Create the rank subdirectory + rank_tag = str(rank).rjust(3, "0") + rank_dir = os.path.join(log_dir, f"rank_{rank_tag}") + rank_dir = expand_outdir_and_mkdir(rank_dir) - log_file = os.path.join(rank_dir, f'rank_{rank_tag}.log') - return create_logger(rank, log_file, name=name, log_level=log_level) + log_file = os.path.join(rank_dir, f"rank_{rank_tag}.log") + return create_logger(rank, log_file, name=name, log_level=log_level) def create_local_logger( rank, local_id, log_dir, - name='local_logger', + name="local_logger", log_level=logging.INFO, ): - # Create the logger - logger = logging.getLogger(name) - logger.setLevel(log_level) + # Create the logger + logger = logging.getLogger(name) + logger.setLevel(log_level) - # Make tags - rank_tag = str(rank).rjust(3, '0') - local_id_tag = str(local_id).rjust(3, '0') + # Make tags + rank_tag = str(rank).rjust(3, "0") + local_id_tag = str(local_id).rjust(3, "0") - myhost = socket.gethostname() - extra = {'host': myhost, 'node': rank_tag, 'local': local_id_tag} - formatter = logging.Formatter('%(asctime)s | %(host)s | Node rank %(node)s ' - '| Local rank %(local)s | %(message)s') + myhost = socket.gethostname() + extra = {"host": myhost, "node": rank_tag, "local": local_id_tag} + formatter = logging.Formatter( + "%(asctime)s | %(host)s | Node rank %(node)s " + "| Local rank %(local)s | %(message)s" + ) - # Output log file - rank_dir = os.path.join(log_dir, f'rank_{rank_tag}') - log_file = os.path.join(rank_dir, f'local_{local_id_tag}.log') + # Output log file + rank_dir = os.path.join(log_dir, f"rank_{rank_tag}") + log_file = os.path.join(rank_dir, f"local_{local_id_tag}.log") - # File handler for output - file_handler = logging.FileHandler(log_file, mode='a') - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # File handler for output + file_handler = 
logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) - logger = logging.LoggerAdapter(logger, extra) + logger = logging.LoggerAdapter(logger, extra) - return logger + return logger diff --git a/nemo_curator/modifiers/__init__.py b/nemo_curator/modifiers/__init__.py index e6790d23..4c05a31e 100644 --- a/nemo_curator/modifiers/__init__.py +++ b/nemo_curator/modifiers/__init__.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .doc_modifier import DocumentModifier from .c4 import BoilerPlateStringModifier +from .doc_modifier import DocumentModifier from .fasttext import FastTextLabelModifier from .unicode_reformatter import UnicodeReformatter -__all__ = ["DocumentModifier", "BoilerPlateStringModifier", "FastTextLabelModifier", "UnicodeReformatter"] \ No newline at end of file +__all__ = [ + "DocumentModifier", + "BoilerPlateStringModifier", + "FastTextLabelModifier", + "UnicodeReformatter", +] diff --git a/nemo_curator/modifiers/c4.py b/nemo_curator/modifiers/c4.py index 631bc71a..36a3d3f1 100644 --- a/nemo_curator/modifiers/c4.py +++ b/nemo_curator/modifiers/c4.py @@ -12,76 +12,79 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_curator.modifiers import DocumentModifier - +from nemo_curator.modifiers.doc_modifier import DocumentModifier from nemo_curator.utils.constants import policy_substrings from nemo_curator.utils.text_utils import ( get_paragraphs, is_paragraph_indices_in_top_or_bottom_only, ) + class BoilerPlateStringModifier(DocumentModifier): - """ - If the sentence contains any of the boilerplate strings then discard. - This includes things like "terms of use", "privacy policy", etc. - Source: Adapted significantly from Google C4 processing. - """ + """ + If the sentence contains any of the boilerplate strings then discard. + This includes things like "terms of use", "privacy policy", etc. + Source: Adapted significantly from Google C4 processing. 
+ """ - def __init__( - self, - remove_if_at_top_or_bottom=True, - ): - super().__init__() - self._remove_if_at_top_or_bottom = remove_if_at_top_or_bottom - self._top_or_bottom_only = False - self._boilerplate_paragraph_indices = [] - self._name = 'boilerplate_string_ratio' + def __init__( + self, + remove_if_at_top_or_bottom=True, + ): + super().__init__() + self._remove_if_at_top_or_bottom = remove_if_at_top_or_bottom + self._top_or_bottom_only = False + self._boilerplate_paragraph_indices = [] + self._name = "boilerplate_string_ratio" - def modify_document(self, text): - # Initialize variables - self._boilerplate_paragraph_indices = [] + def modify_document(self, text): + # Initialize variables + self._boilerplate_paragraph_indices = [] - # Return an empty string when the document should be removed entirely - empty_string = "" + # Return an empty string when the document should be removed entirely + empty_string = "" - # Get the paragraphs - paragraphs = self._paragraphs - if paragraphs is None: - paragraphs = get_paragraphs(text) + # Get the paragraphs + paragraphs = self._paragraphs + if paragraphs is None: + paragraphs = get_paragraphs(text) - # Check each paragraph - for idx, paragraph in enumerate(paragraphs): - paragraph = paragraph.strip().lower() - if 'lorem ipsum' in paragraph: - return empty_string - if any(p in paragraph for p in policy_substrings): - if not self._remove_if_at_top_or_bottom: - return empty_string - else: - self._boilerplate_paragraph_indices.append(idx) + # Check each paragraph + for idx, paragraph in enumerate(paragraphs): + paragraph = paragraph.strip().lower() + if "lorem ipsum" in paragraph: + return empty_string + if any(p in paragraph for p in policy_substrings): + if not self._remove_if_at_top_or_bottom: + return empty_string + else: + self._boilerplate_paragraph_indices.append(idx) - # Keep the document if we did not find any boilerplate - if len(self._boilerplate_paragraph_indices) == 0: - return text + # Keep the document if we did not find any boilerplate + if len(self._boilerplate_paragraph_indices) == 0: + return text - # Mark if boilerplate is only at top or bottom - self._top_or_bottom_only = is_paragraph_indices_in_top_or_bottom_only( - self._boilerplate_paragraph_indices, - len(paragraphs), - ) + # Mark if boilerplate is only at top or bottom + self._top_or_bottom_only = is_paragraph_indices_in_top_or_bottom_only( + self._boilerplate_paragraph_indices, + len(paragraphs), + ) - if self._top_or_bottom_only: - # In case paragraphs is None, recompute it - if self._paragraphs is None: - self._paragraphs = get_paragraphs(text) - modified_doc = '\n\n'.join([ - p for idx, p in enumerate(self._paragraphs) - if idx not in self._boilerplate_paragraph_indices - ]) - # Set the paragraphs back to None as the document has been - # changed - else: - modified_doc = text + if self._top_or_bottom_only: + # In case paragraphs is None, recompute it + if self._paragraphs is None: + self._paragraphs = get_paragraphs(text) + modified_doc = "\n\n".join( + [ + p + for idx, p in enumerate(self._paragraphs) + if idx not in self._boilerplate_paragraph_indices + ] + ) + # Set the paragraphs back to None as the document has been + # changed + else: + modified_doc = text - self._paragraphs = None - return modified_doc \ No newline at end of file + self._paragraphs = None + return modified_doc diff --git a/nemo_curator/modifiers/doc_modifier.py b/nemo_curator/modifiers/doc_modifier.py index 16b4eddf..1bcde8f5 100644 --- a/nemo_curator/modifiers/doc_modifier.py +++ 
b/nemo_curator/modifiers/doc_modifier.py @@ -22,7 +22,7 @@ def __init__(self): self._sentences = None self._paragraphs = None self._ngrams = None - + @abstractmethod def modify_document(self, text): - pass \ No newline at end of file + pass diff --git a/nemo_curator/modifiers/fasttext.py b/nemo_curator/modifiers/fasttext.py index 8aa4d1ef..ceae669f 100644 --- a/nemo_curator/modifiers/fasttext.py +++ b/nemo_curator/modifiers/fasttext.py @@ -14,11 +14,12 @@ from nemo_curator.modifiers import DocumentModifier + class FastTextLabelModifier(DocumentModifier): def __init__(self, label): super().__init__() self.label = label def modify_document(self, text): - text = text.replace('\n', ' ').replace('__label__', ' ') - return f"{self.label} {text}" \ No newline at end of file + text = text.replace("\n", " ").replace("__label__", " ") + return f"{self.label} {text}" diff --git a/nemo_curator/modifiers/pii_modifier.py b/nemo_curator/modifiers/pii_modifier.py index 4316f5c1..4a6ef37c 100644 --- a/nemo_curator/modifiers/pii_modifier.py +++ b/nemo_curator/modifiers/pii_modifier.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Dict +from typing import Dict, List import pandas as pd @@ -46,13 +46,15 @@ class PiiModifierBatched(DocumentModifier): """ - def __init__(self, - language: str = DEFAULT_LANGUAGE, - supported_entities: List[str] = None, - anonymize_action: str = "redact", - batch_size: int = DEFAULT_BATCH_SIZE, - device: str = 'gpu', - **kwargs): + def __init__( + self, + language: str = DEFAULT_LANGUAGE, + supported_entities: List[str] = None, + anonymize_action: str = "redact", + batch_size: int = DEFAULT_BATCH_SIZE, + device: str = "gpu", + **kwargs, + ): super().__init__() self.language = language @@ -65,18 +67,22 @@ def __init__(self, def modify_document(self, text: pd.Series, partition_info: Dict = None): import logging - logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S") - deidentifier = load_object_on_worker( - "deidentifier", - self.load_deidentifier, - {} + logging.basicConfig( + format="%(asctime)s %(levelname)s:%(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", ) + + deidentifier = load_object_on_worker("deidentifier", self.load_deidentifier, {}) try: - output: List[str] = deidentifier.deidentify_text_batch(text.tolist(), self.batch_size) + output: List[str] = deidentifier.deidentify_text_batch( + text.tolist(), self.batch_size + ) except Exception as e: - logging.error(f"Encountered error {str(e)} in partition {partition_info['number']}") + logging.error( + f"Encountered error {str(e)} in partition {partition_info['number']}" + ) return pd.Series([True]) output: pd.Series = pd.Series(output) return output @@ -86,16 +92,19 @@ def load_deidentifier(self): Helper function to load the de-identifier """ import spacy - if self.device == 'gpu': + + if self.device == "gpu": spacy.require_gpu() - from nemo_curator.pii.algorithm import PiiDeidentifier, DEFAULT_MAX_DOC_SIZE + from nemo_curator.pii.algorithm import DEFAULT_MAX_DOC_SIZE, PiiDeidentifier deidentifier: PiiDeidentifier = PiiDeidentifier( language=self.language, supported_entities=self.supported_entities, anonymize_action=self.anonymize_action, - **self.kwargs + **self.kwargs, + ) + deidentifier.analyzer.nlp_engine.nlp[deidentifier.language].max_length = ( + DEFAULT_MAX_DOC_SIZE ) - deidentifier.analyzer.nlp_engine.nlp[deidentifier.language].max_length 
= DEFAULT_MAX_DOC_SIZE return deidentifier diff --git a/nemo_curator/modifiers/unicode_reformatter.py b/nemo_curator/modifiers/unicode_reformatter.py index 806934da..707b58ca 100644 --- a/nemo_curator/modifiers/unicode_reformatter.py +++ b/nemo_curator/modifiers/unicode_reformatter.py @@ -13,11 +13,13 @@ # limitations under the License. import ftfy + from nemo_curator.modifiers import DocumentModifier + class UnicodeReformatter(DocumentModifier): def __init__(self): super().__init__() - + def modify_document(self, text): - return ftfy.fix_text(text) \ No newline at end of file + return ftfy.fix_text(text) diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 01f483ae..2e105a02 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .add_id import AddId from .distributed_data_classifier import DomainClassifier, QualityClassifier from .exact_dedup import ExactDuplicates from .filter import Filter, Score, ScoreFilter @@ -19,7 +20,6 @@ from .meta import Sequential from .modify import Modify from .task import TaskDecontamination -from .add_id import AddId __all__ = [ "DomainClassifier", diff --git a/nemo_curator/modules/add_id.py b/nemo_curator/modules/add_id.py index 168b1836..e8f30739 100644 --- a/nemo_curator/modules/add_id.py +++ b/nemo_curator/modules/add_id.py @@ -13,8 +13,8 @@ # limitations under the License. import dask.dataframe as dd -from dask import delayed import numpy as np +from dask import delayed from nemo_curator.datasets import DocumentDataset @@ -24,7 +24,7 @@ def __init__(self, id_field, id_prefix="doc_id", start_index=0) -> None: self.id_field = id_field self.id_prefix = id_prefix self.start_index = start_index - + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: original_meta = dataset.df.dtypes.to_dict() original_meta[self.id_field] = "object" @@ -33,18 +33,25 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: parition_lengths = [0] for partition in delayed_dataset[:-1]: parition_lengths.append(delayed(len)(partition)) - + lower_id_bounds = delayed(np.cumsum)(parition_lengths) delayed_id_dataset = [] for i, partition in enumerate(delayed_dataset): - delayed_id_dataset.append(delayed(self._add_id_to_partition)(partition, lower_id_bounds[i])) + delayed_id_dataset.append( + delayed(self._add_id_to_partition)(partition, lower_id_bounds[i]) + ) - id_dataset = DocumentDataset(dataset_df=dd.from_delayed(delayed_id_dataset, meta=original_meta)) + id_dataset = DocumentDataset( + dataset_df=dd.from_delayed(delayed_id_dataset, meta=original_meta) + ) return id_dataset - + def _add_id_to_partition(self, partition, partition_start_id): - id_column = [f"{self.id_prefix}-{int(i + self.start_index):010d}" for i in range(partition_start_id, len(partition) + partition_start_id)] + id_column = [ + f"{self.id_prefix}-{int(i + self.start_index):010d}" + for i in range(partition_start_id, len(partition) + partition_start_id) + ] partition[self.id_field] = id_column - return partition \ No newline at end of file + return partition diff --git a/nemo_curator/modules/distributed_data_classifier.py b/nemo_curator/modules/distributed_data_classifier.py index ae65a498..e45d0ba5 100644 --- a/nemo_curator/modules/distributed_data_classifier.py +++ b/nemo_curator/modules/distributed_data_classifier.py @@ -12,11 +12,11 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from abc import ABC, abstractmethod import os -from packaging import version +from abc import ABC, abstractmethod import torch +from packaging import version from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.deberta_v2 import DebertaV2TokenizerFast @@ -34,7 +34,7 @@ class DistributedDataClassifier(ABC): - """ Abstract class for running multi-node multi-GPU data classification """ + """Abstract class for running multi-node multi-GPU data classification""" def __init__( self, @@ -292,7 +292,9 @@ def _inference_per_partition(self, df): else: preds = torch.argmax(probs, dim=1) - df[self.pred_column] = [self.labels[i] for i in preds.to("cpu").numpy().tolist()] + df[self.pred_column] = [ + self.labels[i] for i in preds.to("cpu").numpy().tolist() + ] df[self.prob_column] = probs.to("cpu").numpy().tolist() return df @@ -304,7 +306,9 @@ def _load_cfg_with_tokenizer(self): return cfg def _load_model(self, cfg, device): - model = CustomModel(cfg, out_dim=self.out_dim, config_path=None, pretrained=True) + model = CustomModel( + cfg, out_dim=self.out_dim, config_path=None, pretrained=True + ) model = model.to(device) sd = torch.load(self.model_file_name, map_location="cpu") if "model_state_dict" in sd: diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index 271ba136..5d960ac6 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -23,7 +23,9 @@ from typing import Union import pandas as pd -from dask import dataframe as dd, config +from dask import config +from dask import dataframe as dd + from nemo_curator._compat import DASK_P2P_ERROR from nemo_curator.datasets import DocumentDataset from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py index 563d42b8..4c30de0d 100644 --- a/nemo_curator/modules/filter.py +++ b/nemo_curator/modules/filter.py @@ -12,96 +12,122 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_curator.datasets import DocumentDataset from dask.typing import no_default +from nemo_curator.datasets import DocumentDataset + + class Score: - def __init__(self, score_fn, score_field, text_field="text", batched=False, score_type=None): - """ - Args: - score_fn: The score function that takes in a document string and outputs a score for the document - score_field: The field the score will be stored in. - text_field: The field the documents will be read from - """ - self.score_fn = score_fn - self.score_field = score_field - self.text_field = text_field - self.batched = batched - self.score_type = score_type - - def __call__(self, dataset): - # Set the metadata for the function calls if provided - if self.score_type: - meta = (None, self.score_type) - else: - meta = no_default - - if self.batched: - dataset.df[self.score_field] = dataset.df[self.text_field].map_partitions(self.score_fn, meta=meta) - else: - dataset.df[self.score_field] = dataset.df[self.text_field].apply(self.score_fn, meta=meta) - - return dataset + def __init__( + self, score_fn, score_field, text_field="text", batched=False, score_type=None + ): + """ + Args: + score_fn: The score function that takes in a document string and outputs a score for the document + score_field: The field the score will be stored in. 
+ text_field: The field the documents will be read from + """ + self.score_fn = score_fn + self.score_field = score_field + self.text_field = text_field + self.batched = batched + self.score_type = score_type + + def __call__(self, dataset): + # Set the metadata for the function calls if provided + if self.score_type: + meta = (None, self.score_type) + else: + meta = no_default + + if self.batched: + dataset.df[self.score_field] = dataset.df[self.text_field].map_partitions( + self.score_fn, meta=meta + ) + else: + dataset.df[self.score_field] = dataset.df[self.text_field].apply( + self.score_fn, meta=meta + ) + + return dataset class Filter: - def __init__(self, filter_fn, filter_field, invert=False, batched=False): - """ - Args: - filter_fn: A function that returns True if the document is to be kept - filter_field: The field(s) to be passed into the filter function. - invert: Whether to invert the filter condition - """ - self.filter_fn = filter_fn - self.filter_field = filter_field - self.invert = invert - self.batched = batched - - def __call__(self, dataset): - if self.batched: - bool_mask = dataset.df[self.filter_field].map_partitions(self.filter_fn, meta=(None, bool)) - else: - bool_mask = dataset.df[self.filter_field].apply(self.filter_fn, meta=(None, bool)) - - if self.invert: - bool_mask = ~bool_mask - - return DocumentDataset(dataset.df[bool_mask]) + def __init__(self, filter_fn, filter_field, invert=False, batched=False): + """ + Args: + filter_fn: A function that returns True if the document is to be kept + filter_field: The field(s) to be passed into the filter function. + invert: Whether to invert the filter condition + """ + self.filter_fn = filter_fn + self.filter_field = filter_field + self.invert = invert + self.batched = batched + + def __call__(self, dataset): + if self.batched: + bool_mask = dataset.df[self.filter_field].map_partitions( + self.filter_fn, meta=(None, bool) + ) + else: + bool_mask = dataset.df[self.filter_field].apply( + self.filter_fn, meta=(None, bool) + ) + + if self.invert: + bool_mask = ~bool_mask + + return DocumentDataset(dataset.df[bool_mask]) class ScoreFilter: - def __init__(self, filter_obj, text_field="text", score_field=None, score_type=None, invert=False, batched=False): - """ - Args: - score_field: The field to which the scores will be written. If None, scores will be immediately discarded after use. - """ - self.filter_obj = filter_obj - self.text_field = text_field - self.score_field = score_field - self.score_type = score_type - self.invert = invert - self.batched = batched - - def __call__(self, dataset): - # Set the metadata for the function calls if provided - if self.score_type: - meta = (None, self.score_type) - else: - meta = no_default - - if self.batched: - scores = dataset.df[self.text_field].map_partitions(self.filter_obj.score_document, meta=meta) - else: - scores = dataset.df[self.text_field].apply(self.filter_obj.score_document, meta=meta) - - if self.score_field is not None: - dataset.df[self.score_field] = scores - - if self.batched: - bool_mask = scores.map_partitions(self.filter_obj.keep_document, meta=(None, bool)) - else: - bool_mask = scores.apply(self.filter_obj.keep_document, meta=(None, bool)) - if self.invert: - bool_mask = ~bool_mask - - return DocumentDataset(dataset.df[bool_mask]) + def __init__( + self, + filter_obj, + text_field="text", + score_field=None, + score_type=None, + invert=False, + batched=False, + ): + """ + Args: + score_field: The field to which the scores will be written. 
If None, scores will be immediately discarded after use. + """ + self.filter_obj = filter_obj + self.text_field = text_field + self.score_field = score_field + self.score_type = score_type + self.invert = invert + self.batched = batched + + def __call__(self, dataset): + # Set the metadata for the function calls if provided + if self.score_type: + meta = (None, self.score_type) + else: + meta = no_default + + if self.batched: + scores = dataset.df[self.text_field].map_partitions( + self.filter_obj.score_document, meta=meta + ) + else: + scores = dataset.df[self.text_field].apply( + self.filter_obj.score_document, meta=meta + ) + + if self.score_field is not None: + dataset.df[self.score_field] = scores + + if self.batched: + bool_mask = scores.map_partitions( + self.filter_obj.keep_document, meta=(None, bool) + ) + else: + bool_mask = scores.apply(self.filter_obj.keep_document, meta=(None, bool)) + if self.invert: + bool_mask = ~bool_mask + + return DocumentDataset(dataset.df[bool_mask]) diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 99a3bb84..3b057605 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -41,7 +41,10 @@ ) from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if from nemo_curator.utils.distributed_utils import get_current_client, get_num_workers -from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import convert_str_id_to_int, int_ids_to_str +from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import ( + convert_str_id_to_int, + int_ids_to_str, +) from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( aggregated_anchor_docs_with_bk_read, get_restart_offsets, diff --git a/nemo_curator/modules/meta.py b/nemo_curator/modules/meta.py index ab5b2774..f0d969ee 100644 --- a/nemo_curator/modules/meta.py +++ b/nemo_curator/modules/meta.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. + class Sequential: def __init__(self, modules): self.modules = modules - + def __call__(self, dataset): for module in self.modules: dataset = module(dataset) - return dataset \ No newline at end of file + return dataset diff --git a/nemo_curator/modules/modify.py b/nemo_curator/modules/modify.py index d02ce9c7..24dd4783 100644 --- a/nemo_curator/modules/modify.py +++ b/nemo_curator/modules/modify.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
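# Illustrative sketch, not part of the changeset: the Score, Filter and
# ScoreFilter modules reformatted above all act on a DocumentDataset and compose
# through Sequential from meta.py. Assumptions: a local "./docs_jsonl" directory
# of .jsonl files with a "text" column, and a hypothetical word-count filter
# object; ScoreFilter only needs score_document/keep_document, as used above.
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter
from nemo_curator.modules import Filter, Modify, Score, ScoreFilter, Sequential
from nemo_curator.utils.distributed_utils import read_data
from nemo_curator.utils.file_utils import get_all_files_paths_under


class MinWordCountFilter:
    # Hypothetical filter object, not a NeMo Curator built-in.
    def __init__(self, min_words=50):
        self.min_words = min_words

    def score_document(self, text):
        return len(text.split())

    def keep_document(self, score):
        return score >= self.min_words


files = get_all_files_paths_under("./docs_jsonl")  # placeholder input path
dataset = DocumentDataset(
    read_data(files, file_type="jsonl", backend="pandas", add_filename=True)
)
pipeline = Sequential(
    [
        Modify(UnicodeReformatter()),  # ftfy-based unicode cleanup
        Score(len, score_field="num_chars", score_type=int),
        Filter(lambda n: n > 0, filter_field="num_chars"),
        ScoreFilter(MinWordCountFilter(), score_field="word_count", score_type=int),
    ]
)
curated = pipeline(dataset)  # returns a filtered DocumentDataset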
-from nemo_curator.modifiers import DocumentModifier -from nemo_curator.datasets import DocumentDataset - from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers import DocumentModifier @@ -27,8 +24,12 @@ def __init__(self, modifier: DocumentModifier, text_field="text", batched=False) def __call__(self, dataset: DocumentDataset) -> DocumentDataset: if self.batched: - dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions(self.modifier.modify_document, meta=(None, str)) + dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions( + self.modifier.modify_document, meta=(None, str) + ) else: - dataset.df[self.text_field] = dataset.df[self.text_field].apply(self.modifier.modify_document, meta=(None, str)) + dataset.df[self.text_field] = dataset.df[self.text_field].apply( + self.modifier.modify_document, meta=(None, str) + ) return dataset diff --git a/nemo_curator/modules/task.py b/nemo_curator/modules/task.py index f3e0181b..443679c2 100644 --- a/nemo_curator/modules/task.py +++ b/nemo_curator/modules/task.py @@ -12,20 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Union, List from collections import defaultdict -from functools import reduce, partial +from functools import partial, reduce +from typing import Iterable, List, Union -from dask import delayed import dask.dataframe as dd +from dask import delayed -from nemo_curator.tasks.downstream_task import DownstreamTask from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.text_utils import get_words +from nemo_curator.tasks.downstream_task import DownstreamTask from nemo_curator.utils.distributed_utils import single_partition_write_with_filename +from nemo_curator.utils.text_utils import get_words + class TaskDecontamination: - def __init__(self, tasks: Union[DownstreamTask, Iterable[DownstreamTask]], text_field="text", max_ngram_size=13, max_matches=10, min_document_length=200, remove_char_each_side=200, max_splits=10, removed_dir=None) -> None: + def __init__( + self, + tasks: Union[DownstreamTask, Iterable[DownstreamTask]], + text_field="text", + max_ngram_size=13, + max_matches=10, + min_document_length=200, + remove_char_each_side=200, + max_splits=10, + removed_dir=None, + ) -> None: """ Removes segments of downstream evaluation tasks from a dataset Args: @@ -56,11 +67,18 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: # Perform task decontamintation task_ngrams = self.prepare_task_ngram_count() found_result = self._find_matching_ngrams(task_ngrams, delayed_dataset) - matched_ngrams, ngram_freq = found_result['matched-ngrams'], found_result['ngrams-freq'] - delayed_removed_dataset = self._remove_matching_ngrams(matched_ngrams, ngram_freq, delayed_dataset) + matched_ngrams, ngram_freq = ( + found_result["matched-ngrams"], + found_result["ngrams-freq"], + ) + delayed_removed_dataset = self._remove_matching_ngrams( + matched_ngrams, ngram_freq, delayed_dataset + ) # Restore the dataset to its original format - removed_dataset = DocumentDataset(dataset_df=dd.from_delayed(delayed_removed_dataset, meta=original_meta)) + removed_dataset = DocumentDataset( + dataset_df=dd.from_delayed(delayed_removed_dataset, meta=original_meta) + ) return removed_dataset @@ -74,10 +92,12 @@ def prepare_task_ngram_count(self) -> dict: Computes a dictionary of all ngrams in each task as keys and each value set to 0. 
""" delayed_ngrams = [delayed(task.generate_ngrams)() for task in self.tasks] - aggregated_ngrams = delayed(reduce)(TaskDecontamination._merge_task_ngrams, delayed_ngrams) + aggregated_ngrams = delayed(reduce)( + TaskDecontamination._merge_task_ngrams, delayed_ngrams + ) return aggregated_ngrams - + @staticmethod def _compute_ngram_freq_sorted(task_ngrams): ngrams_freq = defaultdict(int) @@ -94,30 +114,43 @@ def find_matching_ngrams(self, task_ngrams: dict, dataset: DocumentDataset) -> d delayed_dataset = dataset.df.to_delayed() return self._find_matching_ngrams(task_ngrams, delayed_dataset) - + def _find_matching_ngrams(self, task_ngrams: dict, delayed_dataset) -> dict: - task_ngrams_frequency_sorted = delayed(self._compute_ngram_freq_sorted)(task_ngrams) - delayed_counts = [delayed(self._find_ngrams_partition)(partition, task_ngrams, task_ngrams_frequency_sorted) for partition in delayed_dataset] + task_ngrams_frequency_sorted = delayed(self._compute_ngram_freq_sorted)( + task_ngrams + ) + delayed_counts = [ + delayed(self._find_ngrams_partition)( + partition, task_ngrams, task_ngrams_frequency_sorted + ) + for partition in delayed_dataset + ] combined_counts = delayed(reduce)(self._merge_counts, delayed_counts) - formatted_result = delayed(self._format_matching_ngrams_result)(combined_counts, task_ngrams_frequency_sorted) + formatted_result = delayed(self._format_matching_ngrams_result)( + combined_counts, task_ngrams_frequency_sorted + ) return formatted_result - def _find_ngrams_partition(self, dataset_partition, task_ngrams, ngrams_freq_sorted): + def _find_ngrams_partition( + self, dataset_partition, task_ngrams, ngrams_freq_sorted + ): partition_count = defaultdict(int) for document in dataset_partition[self.text_field]: doc_result = self._find_ngrams(document, task_ngrams, ngrams_freq_sorted) - partition_count = TaskDecontamination._merge_counts(partition_count, doc_result) - + partition_count = TaskDecontamination._merge_counts( + partition_count, doc_result + ) + return partition_count - + @staticmethod def _merge_counts(first: dict, second: dict): for ngram, count in second.items(): first[ngram] = first.get(ngram, 0) + count - + return first - + @staticmethod def _format_matching_ngrams_result(matched_ngrams, ngram_freq): return { @@ -143,7 +176,7 @@ def _find_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for i in range(len(words) - self.max_ngram_size + 1): # Check if we found a matching n-gram check_ngram_free = TaskDecontamination._check_text( - words[i:i + self.max_ngram_size], + words[i : i + self.max_ngram_size], task_ngrams, text, positions[i], @@ -162,7 +195,7 @@ def _find_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for ngram_len, _ in ngrams_freq_sorted: # Check if we found a matching n-gram check_ngram_free = TaskDecontamination._check_text( - words[i:i + ngram_len], + words[i : i + ngram_len], task_ngrams, text, positions[i], @@ -185,7 +218,7 @@ def _find_ngrams(self, document, task_ngrams, ngrams_freq_sorted): # check the ending n-gram if ngram_free and len(words) - self.max_ngram_size > 0: # get the last words of the lax max ngram - last_seq_words = words[len(words) - self.max_ngram_size:len(words)] + last_seq_words = words[len(words) - self.max_ngram_size : len(words)] last_seq_start_position = len(words) - self.max_ngram_size # check all n-grams lower than max ngram-len @@ -199,7 +232,7 @@ def _find_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for i in range(len(last_seq_words) - ngram_len + 1): # Check for matching n-grams 
check_ngram_free = TaskDecontamination._check_text( - last_seq_words[i:i + ngram_len], + last_seq_words[i : i + ngram_len], task_ngrams, text, positions[last_seq_start_position + i], @@ -228,22 +261,35 @@ def _check_text(words, task_ngrams, text, start_position, text_buf, local_ngram) # Count the matched n-gram and consider it later local_ngram[seq] += 1 if (start_position + len(seq) + 1) < len(text): - text_buf.append(text[start_position + len(seq) + 1:len(text)]) + text_buf.append(text[start_position + len(seq) + 1 : len(text)]) return False return True - - def remove_matching_ngrams(self, matched_ngrams: dict, ngram_freq: List[tuple], dataset: DocumentDataset): + + def remove_matching_ngrams( + self, matched_ngrams: dict, ngram_freq: List[tuple], dataset: DocumentDataset + ): original_meta = dataset.df.dtypes.to_dict() delayed_dataset = dataset.df.to_delayed() - delayed_removed_dataset = self._remove_matching_ngrams(matched_ngrams, ngram_freq, delayed_dataset) - removed_dataset = DocumentDataset(dataset_df=dd.from_delayed(delayed_removed_dataset, meta=original_meta)) + delayed_removed_dataset = self._remove_matching_ngrams( + matched_ngrams, ngram_freq, delayed_dataset + ) + removed_dataset = DocumentDataset( + dataset_df=dd.from_delayed(delayed_removed_dataset, meta=original_meta) + ) return removed_dataset - - def _remove_matching_ngrams(self, matched_ngrams: dict, ngram_freq: List[tuple], delayed_dataset): + + def _remove_matching_ngrams( + self, matched_ngrams: dict, ngram_freq: List[tuple], delayed_dataset + ): threshhold_ngrams = delayed(self._threshold_ngram_count)(matched_ngrams) - delayed_removed_dataset = [delayed(self._remove_ngrams_partition)(partition, threshhold_ngrams, ngram_freq) for partition in delayed_dataset] + delayed_removed_dataset = [ + delayed(self._remove_ngrams_partition)( + partition, threshhold_ngrams, ngram_freq + ) + for partition in delayed_dataset + ] return delayed_removed_dataset @@ -252,11 +298,15 @@ def _threshold_ngram_count(self, matched_ngrams: dict) -> set: for ngram, count in matched_ngrams.items(): if count <= self.max_matches: filtered_ngrams.add(ngram) - + return filtered_ngrams def _remove_ngrams_partition(self, partition, task_ngrams, ngrams_freq_sorted): - document_fn = partial(self._remove_ngrams, task_ngrams=task_ngrams, ngrams_freq_sorted=ngrams_freq_sorted) + document_fn = partial( + self._remove_ngrams, + task_ngrams=task_ngrams, + ngrams_freq_sorted=ngrams_freq_sorted, + ) split_text = partition[self.text_field].apply(document_fn) num_splits = split_text.apply(len) @@ -270,7 +320,6 @@ def _remove_ngrams_partition(self, partition, task_ngrams, ngrams_freq_sorted): filtered_partition = partition[valid_documents_mask] return filtered_partition.explode(self.text_field, ignore_index=True) - def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): """ Searches for matching n-grams in a document @@ -288,7 +337,7 @@ def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for i in range(len(words) - self.max_ngram_size + 1): # Check if we found a matching n-gram check_ngram_free = self._clean_text( - words[i:i + self.max_ngram_size], + words[i : i + self.max_ngram_size], task_ngrams, text, positions[i], @@ -307,7 +356,7 @@ def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for ngram_len, _ in ngrams_freq_sorted: # Check if we found a matching n-gram check_ngram_free = self._clean_text( - words[i:i + ngram_len], + words[i : i + ngram_len], task_ngrams, text, positions[i], @@ -330,7 +379,7 @@ def 
_remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): # check the ending n-gram if ngram_free and len(words) - self.max_ngram_size > 0: # get the last words of the lax max ngram - last_seq_words = words[len(words) - self.max_ngram_size:len(words)] + last_seq_words = words[len(words) - self.max_ngram_size : len(words)] last_seq_start_position = len(words) - self.max_ngram_size # check all n-grams lower than max ngram-len @@ -344,7 +393,7 @@ def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): for i in range(len(last_seq_words) - ngram_len + 1): # Check for matching n-grams check_ngram_free = self._clean_text( - last_seq_words[i:i + ngram_len], + last_seq_words[i : i + ngram_len], task_ngrams, text, positions[last_seq_start_position + i], @@ -373,7 +422,16 @@ def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted): return text_buf_ngram_free - def _clean_text(self, words, matched_ngrams, text, start_position, text_buf, text_buf_ngram_free, nosplit_remove=False): + def _clean_text( + self, + words, + matched_ngrams, + text, + start_position, + text_buf, + text_buf_ngram_free, + nosplit_remove=False, + ): seq = " ".join(words) if seq in matched_ngrams: print(" [matched]: {}".format(seq), flush=True) @@ -415,7 +473,7 @@ def _split_text(text, start_pos, remove_char_each_side, seq): while pos > 0 and not text[pos] in punctuations: pos -= 1 if pos > 0: - text_first = text[0:pos + 1] + text_first = text[0 : pos + 1] # add length of seq and remove_char_each_side pos = start_pos + len(seq) + remove_char_each_side @@ -425,6 +483,6 @@ def _split_text(text, start_pos, remove_char_each_side, seq): while pos < len(text) and not text[pos] in punctuations: pos += 1 if pos + 1 < len(text): - text_second = text[pos + 1:len(text)] + text_second = text[pos + 1 : len(text)] - return text_first, text_second \ No newline at end of file + return text_first, text_second diff --git a/nemo_curator/pii/algorithm.py b/nemo_curator/pii/algorithm.py index fd6ed5e8..762214ef 100644 --- a/nemo_curator/pii/algorithm.py +++ b/nemo_curator/pii/algorithm.py @@ -13,42 +13,36 @@ # limitations under the License. 
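# Illustrative sketch, not part of the changeset: TaskDecontamination, as
# reformatted above, takes one or more DownstreamTask instances and strips
# matching n-gram segments out of a DocumentDataset. `dataset` is assumed to be
# loaded beforehand (e.g. via read_data), and the concrete task classes named
# here are assumptions; substitute whichever DownstreamTask implementations are
# available in nemo_curator.tasks.
from nemo_curator.modules import TaskDecontamination
from nemo_curator.tasks import Squad, Winogrande  # assumed task classes

decontaminate = TaskDecontamination(
    [Squad(), Winogrande()],  # Union[DownstreamTask, Iterable[DownstreamTask]]
    text_field="text",
    max_ngram_size=13,  # defaults mirror the constructor above
    max_matches=10,
)
decontaminated_dataset = decontaminate(dataset)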
from pathlib import Path -from typing import ( - List, - Mapping, - Any, - Union -) +from typing import Any, List, Mapping, Union import yaml -from presidio_analyzer import ( - AnalyzerEngine, - RecognizerRegistry -) +from presidio_analyzer import AnalyzerEngine, RecognizerRegistry from presidio_analyzer.nlp_engine import NerModelConfiguration from presidio_analyzer.nlp_engine.ner_model_configuration import LABELS_TO_IGNORE from presidio_analyzer.predefined_recognizers import ( + CreditCardRecognizer, EmailRecognizer, + IpRecognizer, PhoneRecognizer, SpacyRecognizer, UsSsnRecognizer, - CreditCardRecognizer, - IpRecognizer -) -from presidio_anonymizer import ( - AnonymizerEngine, - BatchAnonymizerEngine ) +from presidio_anonymizer import AnonymizerEngine, BatchAnonymizerEngine from presidio_anonymizer.entities import OperatorConfig from nemo_curator.pii.custom_batch_analyzer_engine import CustomBatchAnalyzerEngine from nemo_curator.pii.custom_nlp_engine import CustomNlpEngine from nemo_curator.pii.recognizers.address_recognizer import AddressRecognizer -__all__ = ['DEFAULT_LANGUAGE', 'SUPPORTED_ENTITIES', 'DEFAULT_MAX_DOC_SIZE', 'PiiDeidentifier'] +__all__ = [ + "DEFAULT_LANGUAGE", + "SUPPORTED_ENTITIES", + "DEFAULT_MAX_DOC_SIZE", + "PiiDeidentifier", +] -DEFAULT_LANGUAGE = 'en' +DEFAULT_LANGUAGE = "en" SUPPORTED_ENTITIES = [ "ADDRESS", "CREDIT_CARD", @@ -70,12 +64,11 @@ class PiiDeidentifier(object): """Cleans PII from an unstructured text""" def __init__( - self, - language: str = DEFAULT_LANGUAGE, - supported_entities: List[str] = None, - anonymize_action: str = "replace", - **kwargs - + self, + language: str = DEFAULT_LANGUAGE, + supported_entities: List[str] = None, + anonymize_action: str = "replace", + **kwargs ): """ Parameters: @@ -100,21 +93,25 @@ def __init__( "PATORG", "HCW", "HOSPITAL", - "FAC" + "FAC", } LABELS_TO_IGNORE.update(additional_labels_to_ignore) - recognizer_registry = RecognizerRegistry(recognizers=[ - EmailRecognizer(), - PhoneRecognizer(), - SpacyRecognizer(), - UsSsnRecognizer(), - CreditCardRecognizer(), - IpRecognizer() - ]) + recognizer_registry = RecognizerRegistry( + recognizers=[ + EmailRecognizer(), + PhoneRecognizer(), + SpacyRecognizer(), + UsSsnRecognizer(), + CreditCardRecognizer(), + IpRecognizer(), + ] + ) self.language = language - ner_model_configuration = NerModelConfiguration(labels_to_ignore=LABELS_TO_IGNORE) + ner_model_configuration = NerModelConfiguration( + labels_to_ignore=LABELS_TO_IGNORE + ) self.analyzer = AnalyzerEngine( registry=recognizer_registry, nlp_engine=CustomNlpEngine(ner_model_configuration=ner_model_configuration), @@ -134,12 +131,15 @@ def __init__( elif anonymize_action == "mask": self.operators["DEFAULT"] = OperatorConfig( - "mask", {"chars_to_mask": kwargs.get("chars_to_mask", 100), - "masking_char": kwargs.get("masking_char", "*"), - "from_end": False} + "mask", + { + "chars_to_mask": kwargs.get("chars_to_mask", 100), + "masking_char": kwargs.get("masking_char", "*"), + "from_end": False, + }, ) - elif anonymize_action == 'lambda': + elif anonymize_action == "lambda": self.operators["DEFAULT"] = OperatorConfig( "custom", {"lambda": kwargs.get("lambda")} ) @@ -149,9 +149,7 @@ def __init__( "replace", {"new_value": kwargs.get("new_value")} ) - self.supported_entities = ( - supported_entities or SUPPORTED_ENTITIES - ) + self.supported_entities = supported_entities or SUPPORTED_ENTITIES if "ADDRESS" in self.supported_entities: self.add_custom_recognizer( @@ -160,13 +158,13 @@ def __init__( @staticmethod def 
from_config(config: Mapping[str, Any]): - config = config.get('pii_config') - language = config.get('language') - supported_entities = config.get('supported_entities') - operator_config = config.get('anonymize', {}) - operator_name = operator_config.get('action') + config = config.get("pii_config") + language = config.get("language") + supported_entities = config.get("supported_entities") + operator_config = config.get("anonymize", {}) + operator_name = operator_config.get("action") if operator_name: - del operator_config['action'] + del operator_config["action"] return PiiDeidentifier( language=language, @@ -184,7 +182,8 @@ def from_yaml_file(path: Union[Path, str]): def from_default_config(): return PiiDeidentifier( PiiDeidentifier.DEFAULT_LANGUAGE, - supported_entities=SUPPORTED_ENTITIES, anonymize_action='replace' + supported_entities=SUPPORTED_ENTITIES, + anonymize_action="replace", ) def list_supported_entities(self): @@ -204,21 +203,18 @@ def add_custom_operator(self, entity, operator): """Use a custom cleaning operation for a specific entity types""" self.operators[entity] = operator - def analyze_text(self, - text, - entities: List[str] = None, - language: str = 'en' - ): + def analyze_text(self, text, entities: List[str] = None, language: str = "en"): if not entities: entities = self.supported_entities return self.analyzer.analyze(text, language, entities=entities) - def analyze_text_batch(self, - texts: List[str], - entities: List[str] = None, - language: str = 'en', - batch_size: int = 32 - ): + def analyze_text_batch( + self, + texts: List[str], + entities: List[str] = None, + language: str = "en", + batch_size: int = 32, + ): """ For processing batches, use batch analyzer @@ -233,7 +229,9 @@ def analyze_text_batch(self, if not entities: entities = self.supported_entities - return self.batch_analyzer.analyze_iterator(texts, language, entities=entities, batch_size=batch_size) + return self.batch_analyzer.analyze_iterator( + texts, language, entities=entities, batch_size=batch_size + ) def deidentify_text_batch(self, texts: List[str], batch_size: int = 32): """ @@ -248,7 +246,10 @@ def deidentify_text_batch(self, texts: List[str], batch_size: int = 32): List(str): list of deidentified text """ analyzer_results_list = self.batch_analyzer.analyze_iterator( - texts, self.language, entities=self.supported_entities, batch_size=batch_size + texts, + self.language, + entities=self.supported_entities, + batch_size=batch_size, ) anonymized_results_list = self.batch_anonymizer.anonymize_list( @@ -276,10 +277,12 @@ def deidentify_text(self, text: str): if __name__ == "__main__": - txt = "Hello, I am John. I was born on December 5, 1983. My email is john.doe@gmail.com and " \ - "you can call me on (814) 566 4637" + txt = ( + "Hello, I am John. I was born on December 5, 1983. My email is john.doe@gmail.com and " + "you can call me on (814) 566 4637" + ) piid = PiiDeidentifier("en", ["DATE_TIME"]) print(piid.deidentify_text(txt)) - piid = PiiDeidentifier("en", ["ADDRESS", "PERSON"], anonymize_action='replace') + piid = PiiDeidentifier("en", ["ADDRESS", "PERSON"], anonymize_action="replace") print(piid.deidentify_text_batch([txt])) diff --git a/nemo_curator/pii/custom_batch_analyzer_engine.py b/nemo_curator/pii/custom_batch_analyzer_engine.py index a56d4b64..2cdc8411 100644 --- a/nemo_curator/pii/custom_batch_analyzer_engine.py +++ b/nemo_curator/pii/custom_batch_analyzer_engine.py @@ -13,10 +13,15 @@ # limitations under the License. 
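# Illustrative sketch, not part of the changeset: the PiiDeidentifier cleaned up
# above can be driven directly, mirroring the __main__ block of algorithm.py.
# The "mask" variant shown here is an assumption built from the constructor's
# mask branch; the entity names are drawn from the recognizers it registers.
from nemo_curator.pii.algorithm import PiiDeidentifier

deid = PiiDeidentifier(
    language="en",
    supported_entities=["PERSON", "DATE_TIME"],
    anonymize_action="mask",
    chars_to_mask=100,  # kwargs consumed by the "mask" OperatorConfig above
    masking_char="*",
)
print(deid.deidentify_text("I am John. I was born on December 5, 1983."))
print(deid.deidentify_text_batch(["Call me on (814) 566 4637"], batch_size=32))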
import logging -from typing import List, Iterable, Dict, Union, Any, Optional, Iterator - -from presidio_analyzer import DictAnalyzerResult, RecognizerResult, AnalyzerEngine, BatchAnalyzerEngine, \ - EntityRecognizer +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union + +from presidio_analyzer import ( + AnalyzerEngine, + BatchAnalyzerEngine, + DictAnalyzerResult, + EntityRecognizer, + RecognizerResult, +) from presidio_analyzer.nlp_engine import NlpArtifacts from nemo_curator.pii.custom_nlp_engine import CustomNlpEngine @@ -103,11 +108,14 @@ def analyze_iterator( texts=texts, language=language, batch_size=batch_size, - as_tuples=kwargs.get('as_tuples', False) + as_tuples=kwargs.get("as_tuples", False), ) results = self.analyze_batch( - texts=texts, nlp_artifacts_batch=nlp_artifacts_batch, language=language, **kwargs + texts=texts, + nlp_artifacts_batch=nlp_artifacts_batch, + language=language, + **kwargs, ) return results @@ -177,4 +185,3 @@ def analyze_dict( raise ValueError(f"type {type(value)} is unsupported.") yield DictAnalyzerResult(key=key, value=value, recognizer_results=results) - diff --git a/nemo_curator/pii/custom_nlp_engine.py b/nemo_curator/pii/custom_nlp_engine.py index 705c9f5f..e35f3d6f 100644 --- a/nemo_curator/pii/custom_nlp_engine.py +++ b/nemo_curator/pii/custom_nlp_engine.py @@ -13,10 +13,14 @@ # limitations under the License. import logging -from typing import Optional, List, Dict, Union, Tuple, Iterator +from typing import Dict, Iterator, List, Optional, Tuple, Union import spacy -from presidio_analyzer.nlp_engine import SpacyNlpEngine, NerModelConfiguration, NlpArtifacts +from presidio_analyzer.nlp_engine import ( + NerModelConfiguration, + NlpArtifacts, + SpacyNlpEngine, +) from spacy import Language logger = logging.getLogger("presidio-analyzer") @@ -41,14 +45,16 @@ def load(self) -> None: for model in self.models: self._validate_model_params(model) self._download_spacy_model_if_needed(model["model_name"]) - self.nlp[model["lang_code"]] = spacy.load(model["model_name"], enable=["ner"]) + self.nlp[model["lang_code"]] = spacy.load( + model["model_name"], enable=["ner"] + ) def process_batch( self, texts: Union[List[str], List[Tuple[str, object]]], language: str, as_tuples: bool = False, - batch_size: int = 32 + batch_size: int = 32, ) -> Iterator[Optional[NlpArtifacts]]: """Execute the NLP pipeline on a batch of texts using spacy pipe. @@ -64,8 +70,8 @@ def process_batch( raise ValueError("NLP engine is not loaded. 
Consider calling .load()") texts = [str(text) for text in texts] - docs = self.nlp[language].pipe(texts, as_tuples=as_tuples, batch_size=batch_size) + docs = self.nlp[language].pipe( + texts, as_tuples=as_tuples, batch_size=batch_size + ) for doc in docs: yield doc.text, self._doc_to_nlp_artifact(doc, language) - - diff --git a/nemo_curator/sample_dataframe.py b/nemo_curator/sample_dataframe.py index d8341dc0..6a316924 100644 --- a/nemo_curator/sample_dataframe.py +++ b/nemo_curator/sample_dataframe.py @@ -62,11 +62,15 @@ def sample_dataframe(df, num_samples): print("Starting sampling workflow", flush=True) st = time.time() df = read_data( - input_files=get_all_files_paths_under(args.input_file_path, recurse_subdirecties=False), + input_files=get_all_files_paths_under( + args.input_file_path, recurse_subdirecties=False + ), file_type=args.input_file_type, add_filename=True, ) - input_files = get_all_files_paths_under(args.input_file_path, recurse_subdirecties=False) + input_files = get_all_files_paths_under( + args.input_file_path, recurse_subdirecties=False + ) sampled_df = sample_dataframe(df, num_samples=args.num_samples) write_to_disk( df=sampled_df, diff --git a/nemo_curator/scripts/__init__.py b/nemo_curator/scripts/__init__.py index fe99e99a..d9155f92 100644 --- a/nemo_curator/scripts/__init__.py +++ b/nemo_curator/scripts/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/nemo_curator/scripts/add_id.py b/nemo_curator/scripts/add_id.py index 433e1d79..4e49663a 100644 --- a/nemo_curator/scripts/add_id.py +++ b/nemo_curator/scripts/add_id.py @@ -13,43 +13,48 @@ # limitations under the License. 
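# Illustrative sketch, not part of the changeset: the add_id script that follows
# wraps the AddId module reformatted earlier in this diff. Used as a library,
# the same flow looks roughly like this; the "./books_jsonl" paths and the id
# field name are placeholders.
import nemo_curator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import read_data, write_to_disk
from nemo_curator.utils.file_utils import get_all_files_paths_under

files = get_all_files_paths_under("./books_jsonl")
dataset = DocumentDataset(
    read_data(files, file_type="jsonl", backend="pandas", add_filename=True)
)
# Ids look like "doc_id-0000000000", counted contiguously across partitions.
add_id = nemo_curator.AddId("adlr_id", id_prefix="doc_id", start_index=0)
id_dataset = add_id(dataset)
write_to_disk(
    id_dataset.df,
    "./books_jsonl_with_ids",
    write_to_filename=True,
    output_type="jsonl",
)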
import argparse - import random + import nemo_curator from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.script_utils import ( - add_distributed_args, - attach_bool_arg, -) +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import ( expand_outdir_and_mkdir, get_all_files_paths_under, ) -from nemo_curator.utils.distributed_utils import ( - get_client, - read_data, - write_to_disk -) +from nemo_curator.utils.script_utils import add_distributed_args, attach_bool_arg def main(args): - client = get_client(args, args.device) + client = get_client(args, args.device) - output_dir = expand_outdir_and_mkdir(args.output_data_dir) - files = get_all_files_paths_under(args.input_data_dir) - if args.shuffle: - random.seed(args.seed) - random.shuffle(files) + output_dir = expand_outdir_and_mkdir(args.output_data_dir) + files = get_all_files_paths_under(args.input_data_dir) + if args.shuffle: + random.seed(args.seed) + random.shuffle(files) - dataset = DocumentDataset(read_data(files, file_type=args.input_file_type, backend="pandas", add_filename=True)) - add_id = nemo_curator.AddId(args.id_field_name, id_prefix=args.id_prefix, start_index=args.starting_index) - id_dataset = add_id(dataset) + dataset = DocumentDataset( + read_data( + files, file_type=args.input_file_type, backend="pandas", add_filename=True + ) + ) + add_id = nemo_curator.AddId( + args.id_field_name, id_prefix=args.id_prefix, start_index=args.starting_index + ) + id_dataset = add_id(dataset) - write_to_disk(id_dataset.df, output_dir, write_to_filename=True, output_type=args.output_file_type) + write_to_disk( + id_dataset.df, + output_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Adds unique identifiers to each document in the dataset. Creates a new ID field with name specified by the argument "--id-field-name" within each json. @@ -61,78 +66,79 @@ def attach_args(parser=argparse.ArgumentParser( If a document identifier does not already exist for each document, then these ids must be added prior to performing fuzzy/exact deduplication """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--starting-index", - type=int, - default=0, - help="Starting index from which to start indexing the documents", - ) - parser.add_argument( - "--output-data-dir", - type=str, - default=None, - help="The output directory to where the jsonl " - "files with ids will be written. If not specified, the ids will " - "be written in-place", - ) - parser.add_argument( - "--id-field-name", - type=str, - default='adlr_id', - help="The name of the field that will contain the id value. " - "Default is 'adlr_id'", - ) - parser.add_argument( - "--id-prefix", - type=str, - default="doc_id", - help="The prefix to the id number that will be assigned to the " - "document. 
When performing deduplication jointly with different" - "datasets, it is helpful to provide a prefix that denotes that a " - "document belongs to a particular dataset (e.g., wiki for documents" - "that come from the wikipedia dataset)", - ) - attach_bool_arg( - parser, - "shuffle", - help_str="Shuffle the order of files before assigning IDs." - "Useful for creating a copy dataset with different IDs", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="If shuffling is specified, use this random seed to " - "perform the random shuffling", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--output-file-type", - type=str, - default="jsonl", - help="File type the dataset will be written to. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--starting-index", + type=int, + default=0, + help="Starting index from which to start indexing the documents", + ) + parser.add_argument( + "--output-data-dir", + type=str, + default=None, + help="The output directory to where the jsonl " + "files with ids will be written. If not specified, the ids will " + "be written in-place", + ) + parser.add_argument( + "--id-field-name", + type=str, + default="adlr_id", + help="The name of the field that will contain the id value. " + "Default is 'adlr_id'", + ) + parser.add_argument( + "--id-prefix", + type=str, + default="doc_id", + help="The prefix to the id number that will be assigned to the " + "document. When performing deduplication jointly with different" + "datasets, it is helpful to provide a prefix that denotes that a " + "document belongs to a particular dataset (e.g., wiki for documents" + "that come from the wikipedia dataset)", + ) + attach_bool_arg( + parser, + "shuffle", + help_str="Shuffle the order of files before assigning IDs." + "Useful for creating a copy dataset with different IDs", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="If shuffling is specified, use this random seed to " + "perform the random shuffling", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) - parser = add_distributed_args(parser) + parser = add_distributed_args(parser) - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/connected_components.py b/nemo_curator/scripts/connected_components.py index dfafe04d..1ab1282a 100644 --- a/nemo_curator/scripts/connected_components.py +++ b/nemo_curator/scripts/connected_components.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
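# Illustrative sketch, not part of the changeset: driving the add_id script
# above programmatically instead of through its console entry point. The paths
# and prefix are placeholders, and "--device cpu" assumes the flag contributed
# by add_distributed_args.
from nemo_curator.scripts.add_id import attach_args, main

args = attach_args().parse_args(
    [
        "--input-data-dir", "./books_jsonl",
        "--output-data-dir", "./books_jsonl_with_ids",
        "--id-field-name", "adlr_id",
        "--id-prefix", "books",
        "--device", "cpu",
    ]
)
main(args)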
-import time import os +import time + +from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args from nemo_curator.modules.fuzzy_dedup import ConnectedComponents from nemo_curator.utils.distributed_utils import get_client -from nemo_curator.gpu_deduplication.utils import ( - enable_spilling, - parse_nc_args, -) def main(args): diff --git a/nemo_curator/scripts/download_and_extract.py b/nemo_curator/scripts/download_and_extract.py index 38c1b696..4b47a188 100644 --- a/nemo_curator/scripts/download_and_extract.py +++ b/nemo_curator/scripts/download_and_extract.py @@ -14,51 +14,74 @@ import argparse import os + from nemo_curator.download import batch_download, download_and_extract -from nemo_curator.utils.distributed_utils import get_client -from nemo_curator.utils.script_utils import attach_bool_arg, add_distributed_args from nemo_curator.utils.config_utils import build_downloader -from nemo_curator.utils.file_utils import get_all_files_paths_under, expand_outdir_and_mkdir +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.file_utils import ( + expand_outdir_and_mkdir, + get_all_files_paths_under, +) +from nemo_curator.utils.script_utils import add_distributed_args, attach_bool_arg + def read_urls(file_path): - with open(file_path, 'r') as fp: - urls = fp.readlines() - return [url.strip() for url in urls] + with open(file_path, "r") as fp: + urls = fp.readlines() + return [url.strip() for url in urls] + def main(args): - client = get_client(args, args.device) - - if args.input_url_file: - urls = read_urls(args.input_url_file) - outdir = os.path.abspath(os.path.expanduser(args.output_json_dir)) - output_paths = list(map(lambda url: os.path.join(outdir, url.split("/")[-1] + ".jsonl"), urls)) - elif args.input_data_dir: - # If input_data_dir is specified, we operate in extraction only mode. - urls = get_all_files_paths_under(args.input_data_dir) - output_paths = urls - else: - raise ValueError("One of --input-url-file or --input-data-dir must be specified") - - expand_outdir_and_mkdir(args.output_json_dir) - if args.output_download_dir: - raw_download_dir = args.output_download_dir - else: - raw_download_dir = os.path.join(args.output_json_dir, "downloads") - - downloader, iterator, extractor, output_format = build_downloader(args.builder_config_file, default_download_dir=raw_download_dir) - - if args.download_only: - output_paths = batch_download(urls, downloader) - print(f"{len(output_paths)} were downloaded") - return - - dataset = download_and_extract(urls, output_paths, downloader, iterator, extractor, output_format, keep_raw_download=args.keep_downloaded_files, force_download=args.overwrite_existing_json) - - # Sample to trigger the dask computation - sample = dataset.df.sample(frac=10 / len(dataset)).compute() - -def attach_args(parser=argparse.ArgumentParser( - """ + client = get_client(args, args.device) + + if args.input_url_file: + urls = read_urls(args.input_url_file) + outdir = os.path.abspath(os.path.expanduser(args.output_json_dir)) + output_paths = list( + map(lambda url: os.path.join(outdir, url.split("/")[-1] + ".jsonl"), urls) + ) + elif args.input_data_dir: + # If input_data_dir is specified, we operate in extraction only mode. 
+ urls = get_all_files_paths_under(args.input_data_dir) + output_paths = urls + else: + raise ValueError( + "One of --input-url-file or --input-data-dir must be specified" + ) + + expand_outdir_and_mkdir(args.output_json_dir) + if args.output_download_dir: + raw_download_dir = args.output_download_dir + else: + raw_download_dir = os.path.join(args.output_json_dir, "downloads") + + downloader, iterator, extractor, output_format = build_downloader( + args.builder_config_file, default_download_dir=raw_download_dir + ) + + if args.download_only: + output_paths = batch_download(urls, downloader) + print(f"{len(output_paths)} were downloaded") + return + + dataset = download_and_extract( + urls, + output_paths, + downloader, + iterator, + extractor, + output_format, + keep_raw_download=args.keep_downloaded_files, + force_download=args.overwrite_existing_json, + ) + + # Sample to trigger the dask computation + sample = dataset.df.sample(frac=10 / len(dataset)).compute() + + +def attach_args( + parser=argparse.ArgumentParser( + """ Takes an input list of urls and downloads the data and then extracts the text from the downloaded data. Using the --builder-config-file argument, users must provide a YAML file @@ -76,69 +99,70 @@ def attach_args(parser=argparse.ArgumentParser( MPI rank). Additionally, the downloader class should be implemented such that it simply returns the pre-downloaded file """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-url-file", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - required=False, - help="Path to input data directory", - ) - parser.add_argument( - "--output-json-dir", - type=str, - default=None, - help="Output directory to store the extracted text in jsonl files", - ) - attach_bool_arg( - parser, - "download-only", - help_str="Specify this flag if you desire to only download the data" - "files and not extract text from the downloaded files", - ) - parser.add_argument( - "--builder-config-file", - type=str, - default=None, - required=True, - help="YAML file that contains paths to implementations of a downloader, " - "iterator and extractor that will be used in this program " - "to build the documents that make up the output dataset", - ) - attach_bool_arg( - parser, - "keep-downloaded-files", - help_str="If this flag is set to true, the downloaded data files " - "will be kept on disk and not removed after extraction", - ) - parser.add_argument( - "--output-download-dir", - type=str, - default=None, - required=False, - help="The directory to where data files will be written " - "in 'download-only' mode. Specify this argument only when " - "the '--download-only flag is specified'.", - ) - attach_bool_arg( - parser, - "overwrite-existing-json", - help_str="If this flag is specified, then the json data will be " - "overwritten if downloading from the the same file.", - ) - - parser = add_distributed_args(parser) - - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-url-file", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. 
Use this for a distributed file system", + ) + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + required=False, + help="Path to input data directory", + ) + parser.add_argument( + "--output-json-dir", + type=str, + default=None, + help="Output directory to store the extracted text in jsonl files", + ) + attach_bool_arg( + parser, + "download-only", + help_str="Specify this flag if you desire to only download the data" + "files and not extract text from the downloaded files", + ) + parser.add_argument( + "--builder-config-file", + type=str, + default=None, + required=True, + help="YAML file that contains paths to implementations of a downloader, " + "iterator and extractor that will be used in this program " + "to build the documents that make up the output dataset", + ) + attach_bool_arg( + parser, + "keep-downloaded-files", + help_str="If this flag is set to true, the downloaded data files " + "will be kept on disk and not removed after extraction", + ) + parser.add_argument( + "--output-download-dir", + type=str, + default=None, + required=False, + help="The directory to where data files will be written " + "in 'download-only' mode. Specify this argument only when " + "the '--download-only flag is specified'.", + ) + attach_bool_arg( + parser, + "overwrite-existing-json", + help_str="If this flag is specified, then the json data will be " + "overwritten if downloading from the the same file.", + ) + + parser = add_distributed_args(parser) + + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/filter_documents.py b/nemo_curator/scripts/filter_documents.py index e662410f..2d5d12f4 100644 --- a/nemo_curator/scripts/filter_documents.py +++ b/nemo_curator/scripts/filter_documents.py @@ -12,113 +12,155 @@ # See the License for the specific language governing permissions and # limitations under the License. 
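# Illustrative sketch, not part of the changeset: a hypothetical run of the
# download_and_extract script above. The URL list and YAML builder config are
# placeholders; the config must name a downloader, iterator and extractor as
# described in the help text, and "--device cpu" again assumes the flag added
# by add_distributed_args.
from nemo_curator.scripts.download_and_extract import attach_args, main

args = attach_args().parse_args(
    [
        "--input-url-file", "./urls/common_crawl_urls.txt",
        "--builder-config-file", "./config/cc_builder.yaml",
        "--output-json-dir", "./extracted_jsonl",
        "--device", "cpu",
    ]
)
main(args)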
-import os import argparse +import os import nemo_curator -from nemo_curator.utils.file_utils import ( - expand_outdir_and_mkdir, - get_batched_files, -) -from nemo_curator.utils.script_utils import attach_bool_arg, add_distributed_args from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client from nemo_curator.utils.config_utils import build_filter_pipeline +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir, get_batched_files +from nemo_curator.utils.script_utils import add_distributed_args, attach_bool_arg def get_dataframe_complement(original_df, filtered_df): - def partition_complement(part_original_df, partition_info=None): - if not partition_info: - return part_original_df - part_filtered_df = filtered_df.get_partition(partition_info["number"]) - complement_mask = ~part_original_df.index.isin(part_filtered_df.index.persist()) - complement_df = part_original_df[complement_mask] - return complement_df + def partition_complement(part_original_df, partition_info=None): + if not partition_info: + return part_original_df + part_filtered_df = filtered_df.get_partition(partition_info["number"]) + complement_mask = ~part_original_df.index.isin(part_filtered_df.index.persist()) + complement_df = part_original_df[complement_mask] + return complement_df + + return original_df.map_partitions(partition_complement) - return original_df.map_partitions(partition_complement) def get_score_fields(pipeline): - score_fields = [] - for nc_module in pipeline.modules: - if isinstance(nc_module, nemo_curator.Score) or isinstance(nc_module, nemo_curator.ScoreFilter): - if nc_module.score_field: - score_fields.append(nc_module.score_field) - - return score_fields + score_fields = [] + for nc_module in pipeline.modules: + if isinstance(nc_module, nemo_curator.Score) or isinstance( + nc_module, nemo_curator.ScoreFilter + ): + if nc_module.score_field: + score_fields.append(nc_module.score_field) + + return score_fields + def write_scores(df, output_dir): - for column in df.columns: - output_path = os.path.join(output_dir, f"{column}.txt") - df[column].to_csv(output_path, single_file=True, encoding="utf-8", header=False, index=False, mode="a") + for column in df.columns: + output_path = os.path.join(output_dir, f"{column}.txt") + df[column].to_csv( + output_path, + single_file=True, + encoding="utf-8", + header=False, + index=False, + mode="a", + ) + def main(args): - client = get_client(args, args.device) - if args.device == "cpu": - backend = "pandas" - elif args.device == "gpu": - backend = "cudf" - else: - raise ValueError(f"Invalid device '{args.device}'. Please specify either 'cpu' or 'gpu'.") - - # Make the output directories - kept_document_dir = args.output_retained_document_dir - removed_document_dir = args.output_removed_document_dir - if kept_document_dir: - expand_outdir_and_mkdir(kept_document_dir) - if removed_document_dir: - expand_outdir_and_mkdir(removed_document_dir) - - filter_pipeline = build_filter_pipeline(args.filter_config_file) - score_fields = get_score_fields(filter_pipeline) + client = get_client(args, args.device) + if args.device == "cpu": + backend = "pandas" + elif args.device == "gpu": + backend = "cudf" + else: + raise ValueError( + f"Invalid device '{args.device}'. Please specify either 'cpu' or 'gpu'." 
+ ) - for files in get_batched_files(args.input_data_dir, kept_document_dir, args.input_file_type, batch_size=args.batch_size): - # Load the data and filter - dataset = DocumentDataset(read_data(files, file_type=args.input_file_type, backend=backend, add_filename=True)) - curr_dataset = prev_dataset = dataset + # Make the output directories + kept_document_dir = args.output_retained_document_dir + removed_document_dir = args.output_removed_document_dir + if kept_document_dir: + expand_outdir_and_mkdir(kept_document_dir) + if removed_document_dir: + expand_outdir_and_mkdir(removed_document_dir) - # Process each filter individually so we can track which documents are removed at each step - for filter_module in filter_pipeline.modules: - curr_dataset = filter_module(curr_dataset).persist() - - filter_field = None - if isinstance(filter_module, nemo_curator.Filter): - filter_field = filter_module.filter_field - elif isinstance(filter_module, nemo_curator.ScoreFilter): - filter_field = filter_module.score_field - - # Save the documents removed by the filter - if removed_document_dir and filter_field: - removed_df = get_dataframe_complement(prev_dataset.df, curr_dataset.df) - removed_filter_dir = os.path.join(removed_document_dir, filter_field) - expand_outdir_and_mkdir(removed_filter_dir) - write_to_disk(removed_df, removed_filter_dir, write_to_filename=True, output_type=args.output_file_type) - prev_dataset = curr_dataset - filtered_dataset = curr_dataset - filtered_dataset = filter_pipeline(dataset).persist() + filter_pipeline = build_filter_pipeline(args.filter_config_file) + score_fields = get_score_fields(filter_pipeline) - # Write scores to separate directory - if args.output_document_score_dir: - if args.id_field in filtered_dataset.df.columns: - output_df = filtered_dataset.df[[args.id_field, *score_fields]] - else: - output_df = filtered_dataset.df[score_fields] - write_scores(output_df, args.output_document_score_dir) + for files in get_batched_files( + args.input_data_dir, + kept_document_dir, + args.input_file_type, + batch_size=args.batch_size, + ): + # Load the data and filter + dataset = DocumentDataset( + read_data( + files, + file_type=args.input_file_type, + backend=backend, + add_filename=True, + ) + ) + curr_dataset = prev_dataset = dataset - # Remove scores if not logged - if not args.log_scores: - filtered_dataset = DocumentDataset(filtered_dataset.df.drop(columns=score_fields)) + # Process each filter individually so we can track which documents are removed at each step + for filter_module in filter_pipeline.modules: + curr_dataset = filter_module(curr_dataset).persist() - # If kept_document_dir is specified, then create it - if kept_document_dir is not None: - write_to_disk(filtered_dataset.df, kept_document_dir, write_to_filename=True, output_type=args.output_file_type) - else: - # Overwrite the existing files - write_to_disk(filtered_dataset.df, args.input_data_dir, write_to_filename=True, output_type=args.output_file_type) - - client.close() + filter_field = None + if isinstance(filter_module, nemo_curator.Filter): + filter_field = filter_module.filter_field + elif isinstance(filter_module, nemo_curator.ScoreFilter): + filter_field = filter_module.score_field -def attach_args(parser=argparse.ArgumentParser( - """ + # Save the documents removed by the filter + if removed_document_dir and filter_field: + removed_df = get_dataframe_complement(prev_dataset.df, curr_dataset.df) + removed_filter_dir = os.path.join(removed_document_dir, filter_field) + 
expand_outdir_and_mkdir(removed_filter_dir) + write_to_disk( + removed_df, + removed_filter_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) + prev_dataset = curr_dataset + filtered_dataset = curr_dataset + filtered_dataset = filter_pipeline(dataset).persist() + + # Write scores to separate directory + if args.output_document_score_dir: + if args.id_field in filtered_dataset.df.columns: + output_df = filtered_dataset.df[[args.id_field, *score_fields]] + else: + output_df = filtered_dataset.df[score_fields] + write_scores(output_df, args.output_document_score_dir) + + # Remove scores if not logged + if not args.log_scores: + filtered_dataset = DocumentDataset( + filtered_dataset.df.drop(columns=score_fields) + ) + + # If kept_document_dir is specified, then create it + if kept_document_dir is not None: + write_to_disk( + filtered_dataset.df, + kept_document_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) + else: + # Overwrite the existing files + write_to_disk( + filtered_dataset.df, + args.input_data_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) + + client.close() + + +def attach_args( + parser=argparse.ArgumentParser( + """ Main driver script for applying filters to documents distributed across dataset files. Inputs are an input directory consisting of dataset files and a configuration file defining the filter @@ -132,140 +174,142 @@ def attach_args(parser=argparse.ArgumentParser( (and apply to a corpus in distributed fashion), please see the examples directory of this repository """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of dataset files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-local-data-dir", - type=str, - default=None, - help="Input directory consisting of dataset files. " - "Use this argument when a distributed file system is not available.", - ) - parser.add_argument( - "--filter-config-file", - type=str, - required=True, - help="The input filter configuration file that contains the " - "path to the filter module as well as the filter parameters", - ) - parser.add_argument( - "--output-retained-document-dir", - type=str, - default=None, - help="The output directory to where documents that are " - "retained during filtering will be written. If this argument " - "is not specified, then the document scores from the " - "filter(s) will be written to the document meta data in place", - ) - parser.add_argument( - "--output-removed-document-dir", - type=str, - default=None, - help="The output directory to where documents that are removed during " - "filtering will be written. This argument is mainly for quality control " - "in order examine documents that are not preserved during filtering. " - "If it is not specified and the retained-document-dir is specified, " - "then only the retained documents will be written to disk", - ) - attach_bool_arg( - parser, - "filter-only", - default=False, - help_str="Specifying this flag will indicate to the code that only the " - "filtering operation should be performed and that scores should not be " - "computed. 
This flag should be specified if scores have been " - "pre-computed on the documents (e.g., the code was run without the " - "'--output-retained-document-dir' argument) and users desire to apply " - "the filter using the pre-computed scores", - ) - attach_bool_arg( - parser, - "log-scores", - default=False, - help_str="Specifying this flag will cause the computed scores to be " - "logged as additional keys for each document. This only applies to " - "filters with 'log_score: True' in the config. This can aid in " - "performing an interactive quality check of the documents.", - ) - parser.add_argument( - "--output-document-score-dir", - type=str, - default=None, - help="The output directory to where the computed document scores will " - "be written. For each filter, its score will be written to a separate " - "file where each line of the file corresponds to the score computed " - "for each document in the corpus within this directory. This only applies to " - "filters with 'log_score: True' in the config. If this directory is not " - "specified, then filter scores will not be written", - ) - parser.add_argument( - "--id-field", - type=str, - default='adlr_id', - help="The name of the field within each object of the dataset " - "file that assigns a unqiue ID to each document. " - "If this is specified and found within the object, a list of all " - "ids will be written to the output score directory such that each line" - "is consistent with the lines of the written score files ") - attach_bool_arg( - parser, - "keep-node-scores-tmp-dir", - default=False, - help_str="If multiple nodes are used when computing scores, " - "each node will write out its scores to a temporary directory " - "shared across all nodes. Then, the rank 0 node will " - "concatenate all of the scores creating the output file. " - "By default, this directory is removed after concatenation, " - "however users can keep this temporary directory by specifying " - "the flag --keep-node-scores-tmp-dir ", - ) - parser.add_argument( - "--log-frequency", - type=int, - default=10000, - help="The frequency with which to write log messages when " - "computing scores. By default a log message will " - "be written every 10000 documents in a file", - ) - parser.add_argument( - "--log-dir", - type=str, - default="./log/filter_docs", - help="The output log directory where node and local" - " ranks will write their respective log files", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--output-file-type", - type=str, - default="jsonl", - help="File type the dataset will be written to. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--batch-size", - type=int, - default=64, - help="Number of files to read into memory at a time.", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of dataset files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-local-data-dir", + type=str, + default=None, + help="Input directory consisting of dataset files. 
" + "Use this argument when a distributed file system is not available.", + ) + parser.add_argument( + "--filter-config-file", + type=str, + required=True, + help="The input filter configuration file that contains the " + "path to the filter module as well as the filter parameters", + ) + parser.add_argument( + "--output-retained-document-dir", + type=str, + default=None, + help="The output directory to where documents that are " + "retained during filtering will be written. If this argument " + "is not specified, then the document scores from the " + "filter(s) will be written to the document meta data in place", + ) + parser.add_argument( + "--output-removed-document-dir", + type=str, + default=None, + help="The output directory to where documents that are removed during " + "filtering will be written. This argument is mainly for quality control " + "in order examine documents that are not preserved during filtering. " + "If it is not specified and the retained-document-dir is specified, " + "then only the retained documents will be written to disk", + ) + attach_bool_arg( + parser, + "filter-only", + default=False, + help_str="Specifying this flag will indicate to the code that only the " + "filtering operation should be performed and that scores should not be " + "computed. This flag should be specified if scores have been " + "pre-computed on the documents (e.g., the code was run without the " + "'--output-retained-document-dir' argument) and users desire to apply " + "the filter using the pre-computed scores", + ) + attach_bool_arg( + parser, + "log-scores", + default=False, + help_str="Specifying this flag will cause the computed scores to be " + "logged as additional keys for each document. This only applies to " + "filters with 'log_score: True' in the config. This can aid in " + "performing an interactive quality check of the documents.", + ) + parser.add_argument( + "--output-document-score-dir", + type=str, + default=None, + help="The output directory to where the computed document scores will " + "be written. For each filter, its score will be written to a separate " + "file where each line of the file corresponds to the score computed " + "for each document in the corpus within this directory. This only applies to " + "filters with 'log_score: True' in the config. If this directory is not " + "specified, then filter scores will not be written", + ) + parser.add_argument( + "--id-field", + type=str, + default="adlr_id", + help="The name of the field within each object of the dataset " + "file that assigns a unqiue ID to each document. " + "If this is specified and found within the object, a list of all " + "ids will be written to the output score directory such that each line" + "is consistent with the lines of the written score files ", + ) + attach_bool_arg( + parser, + "keep-node-scores-tmp-dir", + default=False, + help_str="If multiple nodes are used when computing scores, " + "each node will write out its scores to a temporary directory " + "shared across all nodes. Then, the rank 0 node will " + "concatenate all of the scores creating the output file. " + "By default, this directory is removed after concatenation, " + "however users can keep this temporary directory by specifying " + "the flag --keep-node-scores-tmp-dir ", + ) + parser.add_argument( + "--log-frequency", + type=int, + default=10000, + help="The frequency with which to write log messages when " + "computing scores. 
By default a log message will " + "be written every 10000 documents in a file", + ) + parser.add_argument( + "--log-dir", + type=str, + default="./log/filter_docs", + help="The output log directory where node and local" + " ranks will write their respective log files", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Number of files to read into memory at a time.", + ) + + parser = add_distributed_args(parser) - parser = add_distributed_args(parser) - - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/find_exact_duplicates.py b/nemo_curator/scripts/find_exact_duplicates.py index a7c84ce2..7da01ea8 100644 --- a/nemo_curator/scripts/find_exact_duplicates.py +++ b/nemo_curator/scripts/find_exact_duplicates.py @@ -16,100 +16,102 @@ import time import dask_cudf + from nemo_curator.datasets import DocumentDataset -from nemo_curator.modules import ExactDuplicates from nemo_curator.gpu_deduplication.ioutils import strip_trailing_sep -from nemo_curator.gpu_deduplication.utils import (create_logger, parse_nc_args) +from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args +from nemo_curator.modules import ExactDuplicates +from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import get_all_files_paths_under -from nemo_curator.utils.distributed_utils import read_data, get_client def pre_imports(): - import cudf # noqa: F401 + import cudf # noqa: F401 def main(args): - logger = create_logger(rank=0, - log_file=os.path.join(args.log_dir, "rank_000.log"), - name="exact_dedup") - logger.info(f"Starting workflow with args:\n {args}") - - assert args.hash_method == "md5", "Currently only md5 hash is supported" - args.set_torch_to_use_rmm = False - client = get_client(args, cluster_type="cpu" if args.no_gpu else "gpu") - logger.info(f"Client Created {client}") - if not args.no_gpu: - client.run(pre_imports) - logger.info("Pre imports complete") - - data_paths = args.input_data_dirs - id_field = args.input_json_id_field - text_field = args.input_json_text_field - num_files = args.num_files - t0 = time.time() - dfs = [] - for data_path in data_paths: - data_path = strip_trailing_sep(data_path) - if num_files is not None and num_files <= 0: - logger.info(f"Processed {num_files}... 
quitting") - break - files = get_all_files_paths_under(root=data_path, - recurse_subdirectories=False) - files = [f for f in files if f.endswith(".jsonl")] - df = read_data( - files[:num_files] if num_files else files, - file_type="jsonl", - backend="pandas" if args.no_gpu else "cudf", - files_per_partition=args.files_per_partition, - add_filename=False, - )[[id_field, text_field]] - if num_files is not None: - num_files -= len(files) - dfs.append(df) - logger.info(f"Lazy read complete for {dfs[-1].npartitions} partitions") - - input_df = dask_cudf.concat(dfs, ignore_unknown_divisions=True) - exact_dups = ExactDuplicates(logger=logger, - id_field=id_field, - text_field=text_field, - hash_method=args.hash_method, - profile_dir=args.profile_path, - cache_dir=args.output_dir) - exact_dups(dataset=DocumentDataset(input_df)) - logger.info( - f"Exact dedup computation across datasets took {time.time() - t0}s complete at {args.output_dir}" # noqa:E501 - ) + logger = create_logger( + rank=0, log_file=os.path.join(args.log_dir, "rank_000.log"), name="exact_dedup" + ) + logger.info(f"Starting workflow with args:\n {args}") + + assert args.hash_method == "md5", "Currently only md5 hash is supported" + args.set_torch_to_use_rmm = False + client = get_client(args, cluster_type="cpu" if args.no_gpu else "gpu") + logger.info(f"Client Created {client}") + if not args.no_gpu: + client.run(pre_imports) + logger.info("Pre imports complete") + + data_paths = args.input_data_dirs + id_field = args.input_json_id_field + text_field = args.input_json_text_field + num_files = args.num_files + t0 = time.time() + dfs = [] + for data_path in data_paths: + data_path = strip_trailing_sep(data_path) + if num_files is not None and num_files <= 0: + logger.info(f"Processed {num_files}... quitting") + break + files = get_all_files_paths_under(root=data_path, recurse_subdirectories=False) + files = [f for f in files if f.endswith(".jsonl")] + df = read_data( + files[:num_files] if num_files else files, + file_type="jsonl", + backend="pandas" if args.no_gpu else "cudf", + files_per_partition=args.files_per_partition, + add_filename=False, + )[[id_field, text_field]] + if num_files is not None: + num_files -= len(files) + dfs.append(df) + logger.info(f"Lazy read complete for {dfs[-1].npartitions} partitions") + + input_df = dask_cudf.concat(dfs, ignore_unknown_divisions=True) + exact_dups = ExactDuplicates( + logger=logger, + id_field=id_field, + text_field=text_field, + hash_method=args.hash_method, + profile_dir=args.profile_path, + cache_dir=args.output_dir, + ) + exact_dups(dataset=DocumentDataset(input_df)) + logger.info( + f"Exact dedup computation across datasets took {time.time() - t0}s complete at {args.output_dir}" # noqa:E501 + ) def attach_args(parser=None): - description = """Compute Exact duplicates in a given dataset. + description = """Compute Exact duplicates in a given dataset. """ - if not parser: - parser = parse_nc_args(description=description) - parser.add_argument( - "--hash-method", - type=str, - default="md5", - help="Hash Method to use for exact dedup", - ) - parser.add_argument( - "--output-dir", - type=str, - required=True, - help="Output directory where duplicate docs will be written. " - "Each file is a pickle file that contains a dictionary of numpy arrays. 
" - "The keys are the document ids and the values are the duplicate docs", - ) - parser.add_argument("--no-gpu", - action="store_true", - help="Use CPU based exact dedup") - - return parser + if not parser: + parser = parse_nc_args(description=description) + parser.add_argument( + "--hash-method", + type=str, + default="md5", + help="Hash Method to use for exact dedup", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory where duplicate docs will be written. " + "Each file is a pickle file that contains a dictionary of numpy arrays. " + "The keys are the document ids and the values are the duplicate docs", + ) + parser.add_argument( + "--no-gpu", action="store_true", help="Use CPU based exact dedup" + ) + + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) if __name__ == "__main__": - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/find_matching_ngrams.py b/nemo_curator/scripts/find_matching_ngrams.py index 76a3e4a5..e8325893 100644 --- a/nemo_curator/scripts/find_matching_ngrams.py +++ b/nemo_curator/scripts/find_matching_ngrams.py @@ -17,96 +17,102 @@ import nemo_curator from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import get_all_files_paths_under from nemo_curator.utils.script_utils import add_distributed_args -from nemo_curator.utils.distributed_utils import get_client, read_data def main(args): - client = get_client(args, args.device) + client = get_client(args, args.device) - # Each rank read in the task data - with open(args.input_task_ngrams, 'rb') as fp: - task_ngrams = pickle.load(fp) + # Each rank read in the task data + with open(args.input_task_ngrams, "rb") as fp: + task_ngrams = pickle.load(fp) - decontaminator = nemo_curator.TaskDecontamination([], text_field=args.input_text_field, max_ngram_size=args.max_ngram_size) + decontaminator = nemo_curator.TaskDecontamination( + [], text_field=args.input_text_field, max_ngram_size=args.max_ngram_size + ) - files = get_all_files_paths_under(args.input_data_dir) - dataset = DocumentDataset(read_data(files, file_type=args.input_file_type, backend="pandas")) + files = get_all_files_paths_under(args.input_data_dir) + dataset = DocumentDataset( + read_data(files, file_type=args.input_file_type, backend="pandas") + ) - result = decontaminator.find_matching_ngrams(task_ngrams, dataset).compute() - print(f"Found a total of {len(result['matched-ngrams'])} matching n-grams") + result = decontaminator.find_matching_ngrams(task_ngrams, dataset).compute() + print(f"Found a total of {len(result['matched-ngrams'])} matching n-grams") - output = { - 'matched-ngrams': result['matched-ngrams'], - 'ngrams-freq': result['ngrams-freq'], - 'max-ngram-size': args.max_ngram_size, - 'min-ngram-size': args.min_ngram_size, - } - with open(args.output_matched_ngram_data, 'wb') as fp: - pickle.dump(output, fp) + output = { + "matched-ngrams": result["matched-ngrams"], + "ngrams-freq": result["ngrams-freq"], + "max-ngram-size": args.max_ngram_size, + "min-ngram-size": args.min_ngram_size, + } + with open(args.output_matched_ngram_data, "wb") as fp: + pickle.dump(output, fp) -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Searches for matching task n-grams in the input dataset and writes out a list of n-grams that were found. 
""", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-text-field", - type=str, - default='text', - help="The name of the field within each datapoint object of the input " - "file that contains the text.", - ) - parser.add_argument( - "--output-matched-ngram-data", - type=str, - default=None, - help="Output dictionary that contains the output matched n-grams " - "and the frequency of their matches, min-ngram size, max-ngram " - "size and the frequencies of n-gram sizes. All of these data will be " - "used by remove_matching_grams for which this program is a prequisite", - ) - parser.add_argument( - "--input-task-ngrams", - type=str, - default=None, - help="", - ) - parser.add_argument( - "--max-ngram-size", - type=int, - default=13, - help="The maximum n-gram size to consider within the dataset", - ) - parser.add_argument( - "--min-ngram-size", - type=int, - default=8, - help="The minimum n-gram size to consider within the datset", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-text-field", + type=str, + default="text", + help="The name of the field within each datapoint object of the input " + "file that contains the text.", + ) + parser.add_argument( + "--output-matched-ngram-data", + type=str, + default=None, + help="Output dictionary that contains the output matched n-grams " + "and the frequency of their matches, min-ngram size, max-ngram " + "size and the frequencies of n-gram sizes. All of these data will be " + "used by remove_matching_grams for which this program is a prequisite", + ) + parser.add_argument( + "--input-task-ngrams", + type=str, + default=None, + help="", + ) + parser.add_argument( + "--max-ngram-size", + type=int, + default=13, + help="The maximum n-gram size to consider within the dataset", + ) + parser.add_argument( + "--min-ngram-size", + type=int, + default=8, + help="The minimum n-gram size to consider within the datset", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. 
Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) - parser = add_distributed_args(parser) + parser = add_distributed_args(parser) - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/find_pii_and_deidentify.py b/nemo_curator/scripts/find_pii_and_deidentify.py index 7682849f..d55a1117 100644 --- a/nemo_curator/scripts/find_pii_and_deidentify.py +++ b/nemo_curator/scripts/find_pii_and_deidentify.py @@ -17,25 +17,31 @@ import time from pathlib import Path -from nemo_curator.modules.modify import Modify from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers.pii_modifier import PiiModifierBatched +from nemo_curator.modules.modify import Modify + # from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE -from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_batched_files from nemo_curator.utils.script_utils import add_distributed_args def main(args): """Main function that performs PII de-identifcation given a batch of files""" - logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.DEBUG, - datefmt="%Y-%m-%d %H:%M:%S") + logging.basicConfig( + format="%(asctime)s %(levelname)s:%(message)s", + level=logging.DEBUG, + datefmt="%Y-%m-%d %H:%M:%S", + ) logging.debug("Beginning PII job") start_time = time.time() Path(args.output_data_dir).mkdir(parents=True, exist_ok=True) - supported_entities = args.supported_entities.split(',') if args.supported_entities else None + supported_entities = ( + args.supported_entities.split(",") if args.supported_entities else None + ) modifier = PiiModifierBatched( language=args.language, @@ -46,35 +52,41 @@ def main(args): masking_char=args.masking_char, new_value=args.new_value, batch_size=args.batch_size, - device=args.device) + device=args.device, + ) for file_names in get_batched_files( - args.input_data_dir, - args.output_data_dir, - args.input_file_type, - args.n_workers + args.input_data_dir, args.output_data_dir, args.input_file_type, args.n_workers ): logging.info("Reading input files....") - source_data = read_data(file_names, file_type=args.input_file_type, backend='pandas', add_filename=True) + source_data = read_data( + file_names, + file_type=args.input_file_type, + backend="pandas", + add_filename=True, + ) dataset = DocumentDataset(source_data) logging.debug(f"Dataset has {source_data.npartitions} partitions") modify = Modify(modifier, batched=True) modified_dataset = modify(dataset) - write_to_disk(modified_dataset.df, - args.output_data_dir, - write_to_filename=True, - output_type=args.output_file_type - ) + write_to_disk( + modified_dataset.df, + args.output_data_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) end_time = time.time() - logging.debug("Total time taken in PII job: %0.3f seconds" % (end_time - start_time)) + logging.debug( + "Total time taken in PII job: %0.3f seconds" % (end_time - start_time) + ) def attach_args( parser=argparse.ArgumentParser( """ - Main driver script for applying PII redaction on documents. Inputs are in the input-data-dir directory. + Main driver script for applying PII redaction on documents. Inputs are in the input-data-dir directory. This script will then perform PII detection and de-identification on each document within the corpus. 
""", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -84,9 +96,9 @@ def attach_args( parser.add_argument( "--language", type=str, - default='en', + default="en", required=False, - help="Language of input documents" + help="Language of input documents", ) parser.add_argument( @@ -94,15 +106,15 @@ def attach_args( type=str, default=None, required=False, - help="Comma separated list of PII entity types. None implies all supported types" + help="Comma separated list of PII entity types. None implies all supported types", ) parser.add_argument( "--anonymize-action", type=str, - default='replace', + default="replace", required=False, - help="Anonymization action. Choose from among: redact, hash, mask and replace" + help="Anonymization action. Choose from among: redact, hash, mask and replace", ) parser.add_argument( @@ -110,7 +122,7 @@ def attach_args( type=str, default=None, required=False, - help="The hash type. Choose from among: sha256, sha512 or md5" + help="The hash type. Choose from among: sha256, sha512 or md5", ) parser.add_argument( @@ -118,15 +130,15 @@ def attach_args( type=int, default=100, required=False, - help="The number of characters to mask. Only applicable if anonymize action is mask" + help="The number of characters to mask. Only applicable if anonymize action is mask", ) parser.add_argument( "--masking-char", type=str, - default='*', + default="*", required=False, - help="The masking character. Only applicable if anonymize action is mask" + help="The masking character. Only applicable if anonymize action is mask", ) parser.add_argument( @@ -134,7 +146,7 @@ def attach_args( type=str, default=None, required=False, - help="The new value to replace with. Only applicable if anonymize action is replace" + help="The new value to replace with. 
Only applicable if anonymize action is replace", ) parser.add_argument( @@ -142,24 +154,24 @@ def attach_args( type=str, default=None, required=True, - help="Directory containing the input files" + help="Directory containing the input files", ) parser.add_argument( "--input-file-type", type=str, - default='jsonl', + default="jsonl", required=True, - choices=['jsonl', 'csv', 'text'], + choices=["jsonl", "csv", "text"], help="The input file type (only jsonl is currently supported)", ) parser.add_argument( "--output-file-type", type=str, - default='jsonl', + default="jsonl", required=True, - choices=['jsonl', 'csv', 'text'], + choices=["jsonl", "csv", "text"], help="The output file type (only jsonl is currently supported)", ) @@ -194,9 +206,9 @@ def console_script(): arguments = add_distributed_args(attach_args()).parse_args() client = get_client(arguments, arguments.device) if not arguments.n_workers: - arguments.n_workers = len(client.scheduler_info()['workers']) + arguments.n_workers = len(client.scheduler_info()["workers"]) main(arguments) if __name__ == "__main__": - console_script() \ No newline at end of file + console_script() diff --git a/nemo_curator/scripts/get_common_crawl_urls.py b/nemo_curator/scripts/get_common_crawl_urls.py index aaf5b728..1e9c6455 100644 --- a/nemo_curator/scripts/get_common_crawl_urls.py +++ b/nemo_curator/scripts/get_common_crawl_urls.py @@ -14,78 +14,87 @@ import argparse -from nemo_curator.utils.script_utils import attach_bool_arg from nemo_curator.utils.download_utils import get_common_crawl_urls +from nemo_curator.utils.script_utils import attach_bool_arg + def main(args): - urls = get_common_crawl_urls(args.starting_snapshot, args.ending_snapshot, data_domain_prefix=args.cc_data_domain_prefix, index_prefix=args.cc_index_prefix, news=args.cc_news) - - with open(args.output_warc_url_file, "w") as fp: - for url in urls: - fp.write(url) - fp.write('\n') + urls = get_common_crawl_urls( + args.starting_snapshot, + args.ending_snapshot, + data_domain_prefix=args.cc_data_domain_prefix, + index_prefix=args.cc_index_prefix, + news=args.cc_news, + ) + + with open(args.output_warc_url_file, "w") as fp: + for url in urls: + fp.write(url) + fp.write("\n") -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Pulls URLs of WARC files stored within the common crawl data repository and writes them to file so that they can be used to subsequently download the WARC files. """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--cc-data-domain-prefix", - type=str, - default="https://data.commoncrawl.org", - help="The prefix that will be prepended to each WARC " - "file to create the URL. By default this value is " - " 'https://data.commoncrawl.org'", - ) - parser.add_argument( - "--cc-index-prefix", - type=str, - default="https://index.commoncrawl.org", - help="The prefix of the URL to the Common Crawl index. " - "By default this value is 'https://index.commoncrawl.org'", - ) - parser.add_argument( - "--output-warc-url-file", - type=str, - default=None, - required=True, - help="The output file to which the WARC urls will be written", - ) - parser.add_argument( - "--starting-snapshot", - type=str, - default="2020-50", - help="The starting snapshot to download. All WARC urls will be written " - "between the dates specified by --starting-snapshot " - "and --ending-snapshot. Snapshots must be specified by YYYY-WeekNumber " - "(e.g., '2020-50' or '2021-04'). 
For the CC-NEWS dataset, " - "(specified with the '--cc-news' flag) this changes to " - "Year-Month (YYYY-MM)", - ) - parser.add_argument( - "--ending-snapshot", - type=str, - default="2020-50", - help="The last snapshot for which WARC urls will be retrieved. " - "Snapshots must be specified by YYYY-WeekNumber " - "(e.g., '2020-50' or '2021-04')", - ) - attach_bool_arg( - parser, - "cc-news", - help_str="Specify --cc-news in order to download WARC URLs for " - "the CC-NEWS dataset instead of the CC-MAIN datasets. If this " - "is specified, then it is assumed that the format for the start " - "and end snapshots is 'YYYY-MM' (Year-Month). All WARC URLs between " - "the specified years and months will be download", - ) - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--cc-data-domain-prefix", + type=str, + default="https://data.commoncrawl.org", + help="The prefix that will be prepended to each WARC " + "file to create the URL. By default this value is " + " 'https://data.commoncrawl.org'", + ) + parser.add_argument( + "--cc-index-prefix", + type=str, + default="https://index.commoncrawl.org", + help="The prefix of the URL to the Common Crawl index. " + "By default this value is 'https://index.commoncrawl.org'", + ) + parser.add_argument( + "--output-warc-url-file", + type=str, + default=None, + required=True, + help="The output file to which the WARC urls will be written", + ) + parser.add_argument( + "--starting-snapshot", + type=str, + default="2020-50", + help="The starting snapshot to download. All WARC urls will be written " + "between the dates specified by --starting-snapshot " + "and --ending-snapshot. Snapshots must be specified by YYYY-WeekNumber " + "(e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, " + "(specified with the '--cc-news' flag) this changes to " + "Year-Month (YYYY-MM)", + ) + parser.add_argument( + "--ending-snapshot", + type=str, + default="2020-50", + help="The last snapshot for which WARC urls will be retrieved. " + "Snapshots must be specified by YYYY-WeekNumber " + "(e.g., '2020-50' or '2021-04')", + ) + attach_bool_arg( + parser, + "cc-news", + help_str="Specify --cc-news in order to download WARC URLs for " + "the CC-NEWS dataset instead of the CC-MAIN datasets. If this " + "is specified, then it is assumed that the format for the start " + "and end snapshots is 'YYYY-MM' (Year-Month). All WARC URLs between " + "the specified years and months will be download", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/get_wikipedia_urls.py b/nemo_curator/scripts/get_wikipedia_urls.py index b216851d..8ccb4f40 100644 --- a/nemo_curator/scripts/get_wikipedia_urls.py +++ b/nemo_curator/scripts/get_wikipedia_urls.py @@ -13,43 +13,49 @@ # limitations under the License. 
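For reference, the URL collection performed by get_common_crawl_urls.py above can also be driven directly from Python. A minimal sketch, assuming the get_common_crawl_urls signature used in that script; the output path is illustrative and not part of this diff:

from nemo_curator.utils.download_utils import get_common_crawl_urls

# Collect WARC URLs for the CC-MAIN snapshots between 2020-50 and 2021-04,
# mirroring the command-line defaults shown above.
urls = get_common_crawl_urls(
    "2020-50",
    "2021-04",
    data_domain_prefix="https://data.commoncrawl.org",
    index_prefix="https://index.commoncrawl.org",
    news=False,
)

# warc_urls.txt is an illustrative output path; one URL per line, as in the script.
with open("warc_urls.txt", "w") as fp:
    for url in urls:
        fp.write(url)
        fp.write("\n")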
import argparse + from nemo_curator.utils.download_utils import get_wikipedia_urls + def main(args): - wikipedia_urls = get_wikipedia_urls(language=args.language, wikidumps_index_prefix=args.wikidumps_index_baseurl) - with open(args.output_url_file, 'w') as output_file: - for url in wikipedia_urls: - output_file.write(url) - output_file.write('\n') + wikipedia_urls = get_wikipedia_urls( + language=args.language, wikidumps_index_prefix=args.wikidumps_index_baseurl + ) + with open(args.output_url_file, "w") as output_file: + for url in wikipedia_urls: + output_file.write(url) + output_file.write("\n") -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Pulls urls pointing to the latest Wikipedia dumps """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--language", - type=str, - default='en', - help="Desired language of the Wikipedia dump", - ) - parser.add_argument( - "--wikidumps-index-baseurl", - type=str, - default='https://dumps.wikimedia.org', - help="The base url for all Wikipedia dumps", - ) - parser.add_argument( - "--output-url-file", - type=str, - default="wikipedia_urls_latest.txt", - help="The output file to which the urls containing " - "the latest dump data will be written", - ) - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--language", + type=str, + default="en", + help="Desired language of the Wikipedia dump", + ) + parser.add_argument( + "--wikidumps-index-baseurl", + type=str, + default="https://dumps.wikimedia.org", + help="The base url for all Wikipedia dumps", + ) + parser.add_argument( + "--output-url-file", + type=str, + default="wikipedia_urls_latest.txt", + help="The output file to which the urls containing " + "the latest dump data will be written", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/jaccard_compute.py b/nemo_curator/scripts/jaccard_compute.py index 41400772..f5915716 100644 --- a/nemo_curator/scripts/jaccard_compute.py +++ b/nemo_curator/scripts/jaccard_compute.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
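The Wikipedia helper reformatted above follows the same pattern. A minimal sketch using the get_wikipedia_urls call from get_wikipedia_urls.py, with that script's default language and index URL (the output filename is illustrative):

from nemo_curator.utils.download_utils import get_wikipedia_urls

# Fetch URLs for the latest English Wikipedia dump, matching the script defaults.
wikipedia_urls = get_wikipedia_urls(
    language="en",
    wikidumps_index_prefix="https://dumps.wikimedia.org",
)

with open("wikipedia_urls_latest.txt", "w") as output_file:
    for url in wikipedia_urls:
        output_file.write(url + "\n")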
-import time import os +import time -from nemo_curator.gpu_deduplication.utils import ( - enable_spilling, - parse_nc_args, -) -from nemo_curator.utils.distributed_utils import get_num_workers, get_client +from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity +from nemo_curator.utils.distributed_utils import get_client, get_num_workers def main(args): @@ -30,7 +27,9 @@ def main(args): """ OUTPUT_PATH = args.output_dir shuffled_docs_path = args.shuffled_docs_path - output_final_results_path = os.path.join(OUTPUT_PATH, "jaccard_similarity_results.parquet") + output_final_results_path = os.path.join( + OUTPUT_PATH, "jaccard_similarity_results.parquet" + ) client = get_client(args, "gpu") enable_spilling() client.run(enable_spilling) diff --git a/nemo_curator/scripts/jaccard_shuffle.py b/nemo_curator/scripts/jaccard_shuffle.py index 441d99f4..dc5d20f9 100644 --- a/nemo_curator/scripts/jaccard_shuffle.py +++ b/nemo_curator/scripts/jaccard_shuffle.py @@ -15,15 +15,20 @@ import os import time +from nemo_curator.gpu_deduplication.utils import ( + get_client, + get_num_workers, + parse_nc_args, +) +from nemo_curator.modules.fuzzy_dedup import _Shuffle from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( get_text_ddf_from_json_path_with_blocksize, ) -from nemo_curator.gpu_deduplication.utils import get_client, get_num_workers, parse_nc_args -from nemo_curator.modules.fuzzy_dedup import _Shuffle def func(): import cudf + from nemo_curator.modules.fuzzy_dedup import _Shuffle @@ -56,7 +61,7 @@ def main(args): id_fields=["dataset_id", "doc_id"], text_field=args.input_json_text_field, profile_dir=args.profile_path, - int_to_str_id="adlr_id" + int_to_str_id="adlr_id", ) shuffle.shuffle_docs_on_buckets( documents_df=text_ddf, diff --git a/nemo_curator/scripts/make_data_shards.py b/nemo_curator/scripts/make_data_shards.py index c9dc2281..f3beae67 100644 --- a/nemo_curator/scripts/make_data_shards.py +++ b/nemo_curator/scripts/make_data_shards.py @@ -14,18 +14,26 @@ import argparse -from nemo_curator.utils.script_utils import add_distributed_args from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.file_utils import reshard_jsonl +from nemo_curator.utils.script_utils import add_distributed_args + def main(args): - client = get_client(args, args.device) + client = get_client(args, args.device) - reshard_jsonl(args.input_data_dir, args.output_resharded_dir, output_file_size=args.output_file_size, start_index=args.start_index, file_prefix=args.prefix) + reshard_jsonl( + args.input_data_dir, + args.output_resharded_dir, + output_file_size=args.output_file_size, + start_index=args.start_index, + file_prefix=args.prefix, + ) -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Makes balanced text files of output size "--block-size" from a directory of input files. The output files will be renamed as output_dir/000.jsonl, output_dir/001.jsonl, ... etc. 
Users @@ -35,47 +43,47 @@ def attach_args(parser=argparse.ArgumentParser( The size of the input files must be larger than the specified "--block-size" """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - required=True, - help="Input directory consisting of .jsonl file(s)", - ) - parser.add_argument( - "--output-resharded-dir", - type=str, - default=None, - required=True, - help="Output directory to where the sharded " - ".jsonl files will be written", - ) - parser.add_argument( - "--output-file-size", - type=str, - default="100M", - help="Approximate size of output files. Must specify with a string and " - "with the unit K, M or G for kilo, mega or gigabytes", - ) - parser.add_argument( - "--start-index", - type=int, - default=0, - help="Starting index for naming the output files", - ) - parser.add_argument( - "--prefix", - type=str, - default="", - help="Prefix to use to prepend to output file number", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + required=True, + help="Input directory consisting of .jsonl file(s)", + ) + parser.add_argument( + "--output-resharded-dir", + type=str, + default=None, + required=True, + help="Output directory to where the sharded " ".jsonl files will be written", + ) + parser.add_argument( + "--output-file-size", + type=str, + default="100M", + help="Approximate size of output files. Must specify with a string and " + "with the unit K, M or G for kilo, mega or gigabytes", + ) + parser.add_argument( + "--start-index", + type=int, + default=0, + help="Starting index for naming the output files", + ) + parser.add_argument( + "--prefix", + type=str, + default="", + help="Prefix to use to prepend to output file number", + ) - parser = add_distributed_args(parser) + parser = add_distributed_args(parser) - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/map_buckets.py b/nemo_curator/scripts/map_buckets.py index e584af3b..522e4f41 100644 --- a/nemo_curator/scripts/map_buckets.py +++ b/nemo_curator/scripts/map_buckets.py @@ -15,12 +15,16 @@ import os import time +from nemo_curator.gpu_deduplication.utils import ( + get_client, + get_num_workers, + parse_nc_args, +) +from nemo_curator.modules.fuzzy_dedup import _MapBuckets from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( get_bucket_ddf_from_parquet_path, get_text_ddf_from_json_path_with_blocksize, ) -from nemo_curator.gpu_deduplication.utils import get_client, get_num_workers, parse_nc_args -from nemo_curator.modules.fuzzy_dedup import _MapBuckets def get_anchor_and_output_map_info( diff --git a/nemo_curator/scripts/minhash_lsh.py b/nemo_curator/scripts/minhash_lsh.py index ecaeacce..fb2c6a90 100644 --- a/nemo_curator/scripts/minhash_lsh.py +++ b/nemo_curator/scripts/minhash_lsh.py @@ -18,10 +18,12 @@ import cudf import dask_cudf import numpy as np + from nemo_curator import LSH from nemo_curator.datasets import DocumentDataset -from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import \ - convert_str_id_to_int +from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import ( + convert_str_id_to_int, +) from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args from nemo_curator.utils.distributed_utils import get_client diff --git 
a/nemo_curator/scripts/prepare_fasttext_training_data.py b/nemo_curator/scripts/prepare_fasttext_training_data.py index 29dd6e41..453d17f2 100644 --- a/nemo_curator/scripts/prepare_fasttext_training_data.py +++ b/nemo_curator/scripts/prepare_fasttext_training_data.py @@ -12,114 +12,126 @@ # See the License for the specific language governing permissions and # limitations under the License. -import csv import argparse +import csv + from nemo_curator.datasets import DocumentDataset -from nemo_curator.modules import Modify from nemo_curator.modifiers import FastTextLabelModifier +from nemo_curator.modules import Modify from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import get_all_files_paths_under from nemo_curator.utils.script_utils import add_distributed_args + def sample_rows(df, n, seed): - samples = df.sample(frac=n / len(df) + 0.05, random_state=seed) + samples = df.sample(frac=n / len(df) + 0.05, random_state=seed) - return samples.head(n=n, compute=False) + return samples.head(n=n, compute=False) def main(args): - client = get_client(args, args.device) - # Get local path - files = list(get_all_files_paths_under(args.input_data_dir)) - raw_data = read_data(files, file_type="jsonl", backend="pandas") - dataset = DocumentDataset(raw_data) - text_field = args.input_json_field - - # fastText requires each document to be prepended with a special label for training - preprocessing = Modify(FastTextLabelModifier(args.label), text_field=text_field) - labeled_data = preprocessing(dataset) + client = get_client(args, args.device) + # Get local path + files = list(get_all_files_paths_under(args.input_data_dir)) + raw_data = read_data(files, file_type="jsonl", backend="pandas") + dataset = DocumentDataset(raw_data) + text_field = args.input_json_field + + # fastText requires each document to be prepended with a special label for training + preprocessing = Modify(FastTextLabelModifier(args.label), text_field=text_field) + labeled_data = preprocessing(dataset) - samples = sample_rows(labeled_data.df, args.output_num_samples, args.seed) + samples = sample_rows(labeled_data.df, args.output_num_samples, args.seed) - samples[text_field].to_csv(args.output_train_file, single_file=True, encoding="utf-8", header=False, index=False, quoting=csv.QUOTE_NONE, sep="\n") - client.close() + samples[text_field].to_csv( + args.output_train_file, + single_file=True, + encoding="utf-8", + header=False, + index=False, + quoting=csv.QUOTE_NONE, + sep="\n", + ) + client.close() -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Prepare data for training skip-gram classifier with FastText Takes as input a directory of .jsonl files, and writes an output file of samples prepared in order to train a skip-gram classifier with FastText. """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-local-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files. " - "Use this argument when a distributed file system is not available.", - ) - parser.add_argument( - "--input-json-field", - type=str, - default='text', - help="The input field within each JSON object on which the filter will " - "operate. 
By default, the filter will operate on the 'text' " - "field but other fields can be specified such as 'url' or 'id'.", - ) - parser.add_argument( - "--output-num-samples", - type=int, - default=None, - required=True, - help="The number of documents to randomly sample from the dataset and" - " use as training and validation samples to train the" - " skip-gram classifier", - ) - parser.add_argument( - "--output-train-file", - type=str, - default=None, - help="The output file containing prepared samples to train a " - "skip-gram classifier with FastText", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="The random seed to use for sampling from the dataset", - ) - parser.add_argument( - "--label", - type=str, - default=None, - required=True, - help="The label to be used at the beginning of each sample " - "in the output file. For example '__label__hq' could be " - "used for the high-quality (positive) samples", - ) - parser.add_argument( - "--log-dir", - type=str, - default="./log/prepare_filter_data", - help="The output log directory where node and local" - " ranks will write their respective log files", - ) - - parser = add_distributed_args(parser) - - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-local-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files. " + "Use this argument when a distributed file system is not available.", + ) + parser.add_argument( + "--input-json-field", + type=str, + default="text", + help="The input field within each JSON object on which the filter will " + "operate. By default, the filter will operate on the 'text' " + "field but other fields can be specified such as 'url' or 'id'.", + ) + parser.add_argument( + "--output-num-samples", + type=int, + default=None, + required=True, + help="The number of documents to randomly sample from the dataset and" + " use as training and validation samples to train the" + " skip-gram classifier", + ) + parser.add_argument( + "--output-train-file", + type=str, + default=None, + help="The output file containing prepared samples to train a " + "skip-gram classifier with FastText", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The random seed to use for sampling from the dataset", + ) + parser.add_argument( + "--label", + type=str, + default=None, + required=True, + help="The label to be used at the beginning of each sample " + "in the output file. For example '__label__hq' could be " + "used for the high-quality (positive) samples", + ) + parser.add_argument( + "--log-dir", + type=str, + default="./log/prepare_filter_data", + help="The output log directory where node and local" + " ranks will write their respective log files", + ) + + parser = add_distributed_args(parser) + + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/prepare_task_data.py b/nemo_curator/scripts/prepare_task_data.py index 2cc017b5..001eb73e 100644 --- a/nemo_curator/scripts/prepare_task_data.py +++ b/nemo_curator/scripts/prepare_task_data.py @@ -13,9 +13,10 @@ # limitations under the License. 
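The fastText preparation script above reduces to labelling each document and writing one sample per line. A minimal sketch of that flow using the same imports, assuming ./data and hq_samples.txt are illustrative paths and omitting the random sampling and Dask client setup that the full script performs:

import csv

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers import FastTextLabelModifier
from nemo_curator.modules import Modify
from nemo_curator.utils.distributed_utils import read_data
from nemo_curator.utils.file_utils import get_all_files_paths_under

# ./data is an illustrative directory of .jsonl files.
files = list(get_all_files_paths_under("./data"))
dataset = DocumentDataset(read_data(files, file_type="jsonl", backend="pandas"))

# Prepend the fastText label used for high-quality (positive) samples.
preprocessing = Modify(FastTextLabelModifier("__label__hq"), text_field="text")
labeled_data = preprocessing(dataset)

# Write one labelled document per line for fastText training.
labeled_data.df["text"].to_csv(
    "hq_samples.txt",
    single_file=True,
    encoding="utf-8",
    header=False,
    index=False,
    quoting=csv.QUOTE_NONE,
    sep="\n",
)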
import argparse -import yaml import pickle +import yaml + import nemo_curator from nemo_curator.tasks.downstream_task import import_task from nemo_curator.utils.distributed_utils import get_client @@ -23,28 +24,29 @@ def main(args): - client = get_client(args, args.device) - # Read in config file - with open(args.task_config_file, 'r') as config_file: - task_params = yaml.load(config_file, Loader=yaml.FullLoader) + client = get_client(args, args.device) + # Read in config file + with open(args.task_config_file, "r") as config_file: + task_params = yaml.load(config_file, Loader=yaml.FullLoader) + + # Generate n-grams for all tasks + task_list = [] + for task in task_params["tasks"]: + print(f"Generating N-grams for task {task['name']}") + task_class = import_task(task["name"]) + task_object = task_class(**task["params"]) + task_list.append(task_object) - # Generate n-grams for all tasks - task_list = [] - for task in task_params['tasks']: - print(f"Generating N-grams for task {task['name']}") - task_class = import_task(task['name']) - task_object = task_class(**task['params']) - task_list.append(task_object) - - decontaminator = nemo_curator.TaskDecontamination(task_list) - all_ngrams = decontaminator.prepare_task_ngram_count() + decontaminator = nemo_curator.TaskDecontamination(task_list) + all_ngrams = decontaminator.prepare_task_ngram_count() - with open(args.output_task_ngrams, 'wb') as fp: - pickle.dump(all_ngrams, fp) + with open(args.output_task_ngrams, "wb") as fp: + pickle.dump(all_ngrams, fp) -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Computes N-grams from input downstream task validation datasets. Takes in an input configuration file (defaults can be found under the config directory under the root directory of the repository) and @@ -52,31 +54,32 @@ def attach_args(parser=argparse.ArgumentParser( used by the program find_matching_ngrams which will search for matching N-grams in the input training dataset. """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--task-config-file", - type=str, - default=None, - required=True, - help="YAML configuration file that contains task information. " - "YAML files for already implemented tasks can be found in the config " - "directory that is located in the root directory of this repository.", - ) - parser.add_argument( - "--output-task-ngrams", - type=str, - default="./task_ngrams.pkl", - help="N-grams computed from input task data. N-grams are stored " - "as keys to a dictionary and the values of the dictionary " - "are the frequencies of which the n-grams occurr within a " - "training dataset (they are initialized to zero within this program)", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--task-config-file", + type=str, + default=None, + required=True, + help="YAML configuration file that contains task information. " + "YAML files for already implemented tasks can be found in the config " + "directory that is located in the root directory of this repository.", + ) + parser.add_argument( + "--output-task-ngrams", + type=str, + default="./task_ngrams.pkl", + help="N-grams computed from input task data. 
N-grams are stored " + "as keys to a dictionary and the values of the dictionary " + "are the frequencies of which the n-grams occurr within a " + "training dataset (they are initialized to zero within this program)", + ) - parser = add_distributed_args(parser) + parser = add_distributed_args(parser) - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/remove_matching_ngrams.py b/nemo_curator/scripts/remove_matching_ngrams.py index 10a59c41..c298e7b3 100644 --- a/nemo_curator/scripts/remove_matching_ngrams.py +++ b/nemo_curator/scripts/remove_matching_ngrams.py @@ -17,57 +17,74 @@ import nemo_curator from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.distributed_utils import ( - read_data, - write_to_disk, - get_client -) +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import ( - get_all_files_paths_under, expand_outdir_and_mkdir, - get_batched_files + get_all_files_paths_under, + get_batched_files, ) from nemo_curator.utils.script_utils import add_distributed_args def main(args): - client = get_client(args, args.device) - - output_tdd_dir = expand_outdir_and_mkdir(args.output_task_deduped_dir) - output_rm_doc_dir = None - if args.output_removed_doc_dir is not None: - output_rm_doc_dir = expand_outdir_and_mkdir(args.output_removed_doc_dir) - - # Each rank read in the task data - print(f"Reading in matched n-grams from {args.input_matched_ngrams}") - with open(args.input_matched_ngrams, 'rb') as fp: - matched_ngram_data = pickle.load(fp) - - # Unpack the results from find_matched_ngrams - matched_ngrams = matched_ngram_data['matched-ngrams'] - ngrams_freq = matched_ngram_data['ngrams-freq'] - max_ngram_size = matched_ngram_data['max-ngram-size'] - - decontaminator = nemo_curator.TaskDecontamination([], - text_field=args.input_text_field, - max_ngram_size=max_ngram_size, - max_matches=args.match_threshold, - max_splits=args.max_document_splits, - removed_dir=output_rm_doc_dir - ) - - files = list(get_all_files_paths_under(args.input_data_dir)) - for files in get_batched_files(args.input_data_dir, output_tdd_dir, args.input_file_type, batch_size=args.batch_size): - dataset = DocumentDataset(read_data(files, file_type=args.input_file_type, backend="pandas", add_filename=True)) - decontaminated_dataset = decontaminator.remove_matching_ngrams(matched_ngrams, ngrams_freq, dataset) - write_to_disk(decontaminated_dataset.df, output_tdd_dir, write_to_filename=True, output_type=args.output_file_type) - print(f"Finished decontaminating {len(files)} files") - - print("Finished decontaminating all files") - - -def attach_args(parser=argparse.ArgumentParser( - """ + client = get_client(args, args.device) + + output_tdd_dir = expand_outdir_and_mkdir(args.output_task_deduped_dir) + output_rm_doc_dir = None + if args.output_removed_doc_dir is not None: + output_rm_doc_dir = expand_outdir_and_mkdir(args.output_removed_doc_dir) + + # Each rank read in the task data + print(f"Reading in matched n-grams from {args.input_matched_ngrams}") + with open(args.input_matched_ngrams, "rb") as fp: + matched_ngram_data = pickle.load(fp) + + # Unpack the results from find_matched_ngrams + matched_ngrams = matched_ngram_data["matched-ngrams"] + ngrams_freq = matched_ngram_data["ngrams-freq"] + max_ngram_size = matched_ngram_data["max-ngram-size"] + + decontaminator = nemo_curator.TaskDecontamination( + [], + 
text_field=args.input_text_field, + max_ngram_size=max_ngram_size, + max_matches=args.match_threshold, + max_splits=args.max_document_splits, + removed_dir=output_rm_doc_dir, + ) + + files = list(get_all_files_paths_under(args.input_data_dir)) + for files in get_batched_files( + args.input_data_dir, + output_tdd_dir, + args.input_file_type, + batch_size=args.batch_size, + ): + dataset = DocumentDataset( + read_data( + files, + file_type=args.input_file_type, + backend="pandas", + add_filename=True, + ) + ) + decontaminated_dataset = decontaminator.remove_matching_ngrams( + matched_ngrams, ngrams_freq, dataset + ) + write_to_disk( + decontaminated_dataset.df, + output_tdd_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) + print(f"Finished decontaminating {len(files)} files") + + print("Finished decontaminating all files") + + +def attach_args( + parser=argparse.ArgumentParser( + """ Using the matching n-grams find by nemo_curator/scripts/find_matching_ngrams.py (provided by the argument --input-matched-ngrams), @@ -75,89 +92,90 @@ def attach_args(parser=argparse.ArgumentParser( splitting documents containing the match. If a document is split more than --max-splits times, it is removed from the corpus. """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-text-field", - type=str, - default='text', - help="The name of the field within each datapoint object of the input " - "file that contains the text.", - ) - parser.add_argument( - "--input-matched-ngrams", - type=str, - default=None, - required=True, - help="Input dictionary (.pkl file), that contains matched " - "n-gram data from the find_matching_ngrams code", - ) - parser.add_argument( - "--output-task-deduped-dir", - type=str, - default=None, - required=True, - help="Output directory to where task-deduplicated (split) " - "documents will be written", - ) - parser.add_argument( - "--output-removed-doc-dir", - type=str, - default=None, - help="Output directory to where removed documents will be written. " - "Documents will be removed from the corpus if they are split more " - "than --max-document-splits number of times, or if the user specifies " - "that they be removed via the flag, --remove-split-docs", - ) - parser.add_argument( - "--match-threshold", - type=int, - default=10, - help="A threshold that determines if a matched n-gram will be " - "considered for removal in remove_matching_ngrams. N-grams that " - "exceed this number of matches in the training dataset will not be " - "considered during the removal stage", - ) - parser.add_argument( - "--max-document-splits", - type=int, - default=10, - help="A threshold used to determine if a document should be removed " - "from the corpus if it is split more than " - "--max-document-splits number of times", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--output-file-type", - type=str, - default="jsonl", - help="File type the dataset will be written to. 
Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--batch-size", - type=int, - default=64, - help="Number of files to read into memory at a time.", - ) - - parser = add_distributed_args(parser) - - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-text-field", + type=str, + default="text", + help="The name of the field within each datapoint object of the input " + "file that contains the text.", + ) + parser.add_argument( + "--input-matched-ngrams", + type=str, + default=None, + required=True, + help="Input dictionary (.pkl file), that contains matched " + "n-gram data from the find_matching_ngrams code", + ) + parser.add_argument( + "--output-task-deduped-dir", + type=str, + default=None, + required=True, + help="Output directory to where task-deduplicated (split) " + "documents will be written", + ) + parser.add_argument( + "--output-removed-doc-dir", + type=str, + default=None, + help="Output directory to where removed documents will be written. " + "Documents will be removed from the corpus if they are split more " + "than --max-document-splits number of times, or if the user specifies " + "that they be removed via the flag, --remove-split-docs", + ) + parser.add_argument( + "--match-threshold", + type=int, + default=10, + help="A threshold that determines if a matched n-gram will be " + "considered for removal in remove_matching_ngrams. N-grams that " + "exceed this number of matches in the training dataset will not be " + "considered during the removal stage", + ) + parser.add_argument( + "--max-document-splits", + type=int, + default=10, + help="A threshold used to determine if a document should be removed " + "from the corpus if it is split more than " + "--max-document-splits number of times", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Number of files to read into memory at a time.", + ) + + parser = add_distributed_args(parser) + + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/separate_by_metadata.py b/nemo_curator/scripts/separate_by_metadata.py index 617ace79..f7aab294 100644 --- a/nemo_curator/scripts/separate_by_metadata.py +++ b/nemo_curator/scripts/separate_by_metadata.py @@ -13,109 +13,119 @@ # limitations under the License. 
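Taken together with prepare_task_data.py and find_matching_ngrams.py, the removal step above consumes the matched n-gram pickle and splits or drops contaminated documents. A minimal sketch, assuming matched_ngrams.pkl and ./train_data are illustrative paths and using the documented defaults for the two thresholds:

import pickle

import nemo_curator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import read_data
from nemo_curator.utils.file_utils import get_all_files_paths_under

# matched_ngrams.pkl is the output of find_matching_ngrams.py (illustrative path).
with open("matched_ngrams.pkl", "rb") as fp:
    matched = pickle.load(fp)

decontaminator = nemo_curator.TaskDecontamination(
    [],
    text_field="text",
    max_ngram_size=matched["max-ngram-size"],
    max_matches=10,  # --match-threshold default
    max_splits=10,  # --max-document-splits default
)

# ./train_data is an illustrative directory of .jsonl training shards.
files = get_all_files_paths_under("./train_data")
dataset = DocumentDataset(read_data(files, file_type="jsonl", backend="pandas"))

deduped = decontaminator.remove_matching_ngrams(
    matched["matched-ngrams"], matched["ngrams-freq"], dataset
)
# deduped.df can then be written out with write_to_disk, as in the script above.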
import argparse -import shutil import json +import shutil -from nemo_curator.utils.file_utils import separate_by_metadata, get_all_files_paths_under, expand_outdir_and_mkdir from nemo_curator.utils.distributed_utils import get_client, read_data -from nemo_curator.utils.script_utils import ( - attach_bool_arg, - add_distributed_args, +from nemo_curator.utils.file_utils import ( + expand_outdir_and_mkdir, + get_all_files_paths_under, + separate_by_metadata, ) +from nemo_curator.utils.script_utils import add_distributed_args, attach_bool_arg def main(args): - client = get_client(args, args.device) + client = get_client(args, args.device) + + files = get_all_files_paths_under(args.input_data_dir) + input_data = read_data( + files, file_type=args.input_file_type, backend="pandas", add_filename=True + ) - files = get_all_files_paths_under(args.input_data_dir) - input_data = read_data(files, file_type=args.input_file_type, backend="pandas", add_filename=True) - - output_dir = expand_outdir_and_mkdir(args.output_data_dir) + output_dir = expand_outdir_and_mkdir(args.output_data_dir) - metadata_field = args.input_metadata_field - print(f"Beginning metadata separation for {metadata_field}") - metadata_distribution = separate_by_metadata(input_data, output_dir, metadata_field, remove_metadata=args.remove_metadata_field, output_type=args.output_file_type).compute() - print(f"Finished metadata separation for {metadata_field}") + metadata_field = args.input_metadata_field + print(f"Beginning metadata separation for {metadata_field}") + metadata_distribution = separate_by_metadata( + input_data, + output_dir, + metadata_field, + remove_metadata=args.remove_metadata_field, + output_type=args.output_file_type, + ).compute() + print(f"Finished metadata separation for {metadata_field}") - with open(args.output_metadata_distribution, 'w') as fp: - json.dump(metadata_distribution, fp) - - if args.remove_input_dir: - print(f"Removing all files in {args.input_data_dir}") - shutil.rmtree(args.input_data_dir) - print(f"Finished removing all files in {args.input_data_dir}") + with open(args.output_metadata_distribution, "w") as fp: + json.dump(metadata_distribution, fp) + if args.remove_input_dir: + print(f"Removing all files in {args.input_data_dir}") + shutil.rmtree(args.input_data_dir) + print(f"Finished removing all files in {args.input_data_dir}") -def attach_args(parser=argparse.ArgumentParser( - """ +def attach_args( + parser=argparse.ArgumentParser( + """ Spits a dataset into subdirectories based on metadata values """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-metadata-field", - type=str, - default='language', - help="The name of the field within each datapoint object of the input " - "file that the dataset should be separated by.", - ) - parser.add_argument( - "--output-metadata-distribution", - type=str, - help="Output json file containing the frequency of documents " - "that occur for a particular metadata.", - ) - parser.add_argument( - "--output-data-dir", - type=str, - required=True, - help="The output directory to where the metadata-separated " - "files will be written. 
Each file will be written to its " - "respective metadata directory that is a sub-directory " - "of this directory", - ) - attach_bool_arg( - parser, - "remove-metadata-field", - default=False, - help_str="Option of whether to remove the metadata field " - "after filtering. Useful only in the case in which one metadata " - "is desired to be separated from the others", - ) - attach_bool_arg( - parser, - "remove-input-dir", - default=False, - help_str="Specify '--remove-input-dir' to remove the original " - "input directory. This is false by default.", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--output-file-type", - type=str, - default="jsonl", - help="File type the dataset will be written to. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-metadata-field", + type=str, + default="language", + help="The name of the field within each datapoint object of the input " + "file that the dataset should be separated by.", + ) + parser.add_argument( + "--output-metadata-distribution", + type=str, + help="Output json file containing the frequency of documents " + "that occur for a particular metadata.", + ) + parser.add_argument( + "--output-data-dir", + type=str, + required=True, + help="The output directory to where the metadata-separated " + "files will be written. Each file will be written to its " + "respective metadata directory that is a sub-directory " + "of this directory", + ) + attach_bool_arg( + parser, + "remove-metadata-field", + default=False, + help_str="Option of whether to remove the metadata field " + "after filtering. Useful only in the case in which one metadata " + "is desired to be separated from the others", + ) + attach_bool_arg( + parser, + "remove-input-dir", + default=False, + help_str="Specify '--remove-input-dir' to remove the original " + "input directory. This is false by default.", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. 
Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) - parser = add_distributed_args(parser) + parser = add_distributed_args(parser) - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/text_cleaning.py b/nemo_curator/scripts/text_cleaning.py index 3a535cd3..5687cfc1 100644 --- a/nemo_curator/scripts/text_cleaning.py +++ b/nemo_curator/scripts/text_cleaning.py @@ -15,34 +15,52 @@ import argparse import nemo_curator -from nemo_curator.modifiers import UnicodeReformatter from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.file_utils import ( - expand_outdir_and_mkdir, - get_batched_files, -) +from nemo_curator.modifiers import UnicodeReformatter from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import expand_outdir_and_mkdir, get_batched_files from nemo_curator.utils.script_utils import add_distributed_args def main(args): - client = get_client(args, args.device) + client = get_client(args, args.device) + + # Make the output directories + output_clean_dir = expand_outdir_and_mkdir(args.output_clean_dir) - # Make the output directories - output_clean_dir = expand_outdir_and_mkdir(args.output_clean_dir) + cleaner = nemo_curator.Modify( + UnicodeReformatter(), text_field=args.input_text_field + ) - cleaner = nemo_curator.Modify(UnicodeReformatter(), text_field=args.input_text_field) + for files in get_batched_files( + args.input_data_dir, + output_clean_dir, + args.input_file_type, + batch_size=args.batch_size, + ): + dataset = DocumentDataset( + read_data( + files, + file_type=args.input_file_type, + backend="pandas", + add_filename=True, + ) + ) + cleaned_dataset = cleaner(dataset) + write_to_disk( + cleaned_dataset.df, + output_clean_dir, + write_to_filename=True, + output_type=args.output_file_type, + ) + print(f"Finished reformatting {len(files)} files") - for files in get_batched_files(args.input_data_dir, output_clean_dir, args.input_file_type, batch_size=args.batch_size): - dataset = DocumentDataset(read_data(files, file_type=args.input_file_type, backend="pandas", add_filename=True)) - cleaned_dataset = cleaner(dataset) - write_to_disk(cleaned_dataset.df, output_clean_dir, write_to_filename=True, output_type=args.output_file_type) - print(f"Finished reformatting {len(files)} files") - - print("Finished reformatting all files") + print("Finished reformatting all files") -def attach_args(parser=argparse.ArgumentParser( - """ + +def attach_args( + parser=argparse.ArgumentParser( + """ Text cleaning and language filtering Takes as input a directory consisting of .jsonl files with one @@ -50,55 +68,55 @@ def attach_args(parser=argparse.ArgumentParser( with fixed unicode. Also, performs language filtering using the 'language' field within each JSON object. """, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--input-data-dir", - type=str, - default=None, - help="Input directory consisting of .jsonl files that are accessible " - "to all nodes. Use this for a distributed file system", - ) - parser.add_argument( - "--input-text-field", - type=str, - default='text', - help="The name of the field within each datapoint object of the input " - "file that contains the text.", - ) - parser.add_argument( - "--input-file-type", - type=str, - default="jsonl", - help="File type of the dataset to be read in. 
Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--output-clean-dir", - type=str, - default=None, - required=True, - help="The output directory to where the cleaned " - "jsonl files will be written", - ) - parser.add_argument( - "--output-file-type", - type=str, - default="jsonl", - help="File type the dataset will be written to. Supported file formats" - " include 'jsonl' (default), 'pickle', or 'parquet'.", - ) - parser.add_argument( - "--batch-size", - type=int, - default=64, - help="Number of files to read into memory at a time.", - ) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dir", + type=str, + default=None, + help="Input directory consisting of .jsonl files that are accessible " + "to all nodes. Use this for a distributed file system", + ) + parser.add_argument( + "--input-text-field", + type=str, + default="text", + help="The name of the field within each datapoint object of the input " + "file that contains the text.", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-clean-dir", + type=str, + default=None, + required=True, + help="The output directory to where the cleaned " "jsonl files will be written", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Number of files to read into memory at a time.", + ) + + parser = add_distributed_args(parser) - parser = add_distributed_args(parser) - - return parser + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/scripts/train_fasttext.py b/nemo_curator/scripts/train_fasttext.py index 33b0a06c..849949af 100644 --- a/nemo_curator/scripts/train_fasttext.py +++ b/nemo_curator/scripts/train_fasttext.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import argparse +import json +import os import random + import fasttext -import json from sklearn.metrics import confusion_matrix from tqdm import tqdm @@ -25,88 +26,89 @@ def main(args): - # Set the random seed for shuffling - random.seed(args.seed) - - # Read in all samples into a list and shuffle the samples - documents = [] - for ifile in get_all_files_paths_under(args.fasttext_files_dir): - if os.path.splitext(ifile)[-1] == '.txt': - with open(ifile, 'r') as fp: - documents += fp.readlines() - random.shuffle(documents) - - # Get total number of samples - nlines = len(documents) - - num_train = round(nlines * args.validation_split) - - # Split into training and validation samples - train_samples = documents[:num_train] - valid_samples = documents[num_train:] - - # Write out the training and validation samples - with open(args.output_train_file, 'w') as fp: - for document in train_samples: - fp.write(document) - - with open(args.output_validation_file, 'w') as fp: - for document in valid_samples: - fp.write(document) - - # Train the model - model = fasttext.train_supervised( - input=args.output_train_file, - lr=args.learning_rate, - dim=args.word_vector_dim, - epoch=args.num_epochs, - wordNgrams=args.wordNgrams, - ) - - # Save the classifier as a FastText model - model.save_model(args.output_model) - - if args.output_predictions is not None: - fout = open(args.output_predictions, 'wb') - - # Read in the model and compute accuracy and other metrics on the data - hq_label = args.high_quality_label - prds, lbls = [], [] - with open(args.output_validation_file, 'r') as f: - for line in tqdm(f.readlines()): - # Split the text and the label - label_t, doc = line.split(' ', 1) - doc = doc.rstrip() - labels_p, scores = model.predict(doc, k=2) - # Write the predictions to file - if args.output_predictions is not None: - line = { - 'text': doc, - 'label': label_t, - f'{labels_p[0]}': scores[0], - f'{labels_p[1]}': scores[1], - } - myjson = json.dumps(line, ensure_ascii=False) - fout.write(myjson.encode('utf-8')) - fout.write('\n'.encode('utf-8')) - # Save predictions and labels - prds.append(1) if labels_p[0] == hq_label else prds.append(0) - lbls.append(1) if label_t == hq_label else lbls.append(0) - - # Print out the metrics computed on the validation data - tn, fp, fn, tp = confusion_matrix(prds, lbls).ravel() - print(f"TN={tn} FP={fp} FN={fn} TP={tp}") - - accuracy = (tp + tn) / (tp + tn + fp + fn) - precision = tp / (tp + fp) - recall = tp / (tp + fn) - f1 = 2 * tp / (2 * tp + fp + fn) - - print(f"Acc={accuracy} Prec={precision} Rec={recall} f1={f1}") - - -def attach_args(parser=argparse.ArgumentParser( - """ + # Set the random seed for shuffling + random.seed(args.seed) + + # Read in all samples into a list and shuffle the samples + documents = [] + for ifile in get_all_files_paths_under(args.fasttext_files_dir): + if os.path.splitext(ifile)[-1] == ".txt": + with open(ifile, "r") as fp: + documents += fp.readlines() + random.shuffle(documents) + + # Get total number of samples + nlines = len(documents) + + num_train = round(nlines * args.validation_split) + + # Split into training and validation samples + train_samples = documents[:num_train] + valid_samples = documents[num_train:] + + # Write out the training and validation samples + with open(args.output_train_file, "w") as fp: + for document in train_samples: + fp.write(document) + + with open(args.output_validation_file, "w") as fp: + for document in valid_samples: + fp.write(document) + + # Train the model + model = 
fasttext.train_supervised( + input=args.output_train_file, + lr=args.learning_rate, + dim=args.word_vector_dim, + epoch=args.num_epochs, + wordNgrams=args.wordNgrams, + ) + + # Save the classifier as a FastText model + model.save_model(args.output_model) + + if args.output_predictions is not None: + fout = open(args.output_predictions, "wb") + + # Read in the model and compute accuracy and other metrics on the data + hq_label = args.high_quality_label + prds, lbls = [], [] + with open(args.output_validation_file, "r") as f: + for line in tqdm(f.readlines()): + # Split the text and the label + label_t, doc = line.split(" ", 1) + doc = doc.rstrip() + labels_p, scores = model.predict(doc, k=2) + # Write the predictions to file + if args.output_predictions is not None: + line = { + "text": doc, + "label": label_t, + f"{labels_p[0]}": scores[0], + f"{labels_p[1]}": scores[1], + } + myjson = json.dumps(line, ensure_ascii=False) + fout.write(myjson.encode("utf-8")) + fout.write("\n".encode("utf-8")) + # Save predictions and labels + prds.append(1) if labels_p[0] == hq_label else prds.append(0) + lbls.append(1) if label_t == hq_label else lbls.append(0) + + # Print out the metrics computed on the validation data + tn, fp, fn, tp = confusion_matrix(prds, lbls).ravel() + print(f"TN={tn} FP={fp} FN={fn} TP={tp}") + + accuracy = (tp + tn) / (tp + tn + fp + fn) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * tp / (2 * tp + fp + fn) + + print(f"Acc={accuracy} Prec={precision} Rec={recall} f1={f1}") + + +def attach_args( + parser=argparse.ArgumentParser( + """ Train a skip-gram quality classifier with FastText Takes as input files with prepared samples for training @@ -114,92 +116,92 @@ def attach_args(parser=argparse.ArgumentParser( classifier on the input samples and writes the trained classifier out to disk as a FastText model. 
""", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -)): - parser.add_argument( - "--fasttext-files-dir", - type=str, - default=None, - required=True, - help="The input directory containing the file(s) " - "containing the prepared FastText samples", - ) - parser.add_argument( - "--high-quality-label", - type=str, - default="__label__hq", - help="The label assigned to the high quality samples " - "when preparing the data", - ) - parser.add_argument( - "--seed", - type=int, - default=1992, - help="The seed used for randomly shuffling the documents", - ) - parser.add_argument( - "--output-train-file", - type=str, - default="./fasttext_samples.train", - help="The concatenated, shuffled samples used " - "to train the skip-gram classifier", - ) - parser.add_argument( - "--output-validation-file", - type=str, - default="./fasttext_samples.valid", - help="The concatenated, shuffled samples used to " - "for computing validation metrics", - ) - parser.add_argument( - "--validation-split", - type=float, - default=0.9, - help="The training validation split", - ) - parser.add_argument( - "--output-model", - type=str, - default=None, - required=True, - help="The output trained skip-gram classifier written " - "as a FastText model", - ) - parser.add_argument( - "--wordNgrams", - type=int, - default=2, - help="The size of the word n-gram used to train the classifier " - "(default is bigram)", - ) - parser.add_argument( - "--learning-rate", - type=float, - default=0.1, - help="The learning rate used to train the classifier", - ) - parser.add_argument( - "--num-epochs", - type=int, - default=5, - help="Number of epochs used to train the classifier", - ) - parser.add_argument( - "--word-vector-dim", - type=int, - default=100, - help="Size of word vectors to be computed by the model", - ) - parser.add_argument( - "--output-predictions", - type=str, - default=None, - help="The output predictions on the validation data. 
" - "If a file is not specified, the predictions are not " - "written to file", - ) - return parser + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--fasttext-files-dir", + type=str, + default=None, + required=True, + help="The input directory containing the file(s) " + "containing the prepared FastText samples", + ) + parser.add_argument( + "--high-quality-label", + type=str, + default="__label__hq", + help="The label assigned to the high quality samples " + "when preparing the data", + ) + parser.add_argument( + "--seed", + type=int, + default=1992, + help="The seed used for randomly shuffling the documents", + ) + parser.add_argument( + "--output-train-file", + type=str, + default="./fasttext_samples.train", + help="The concatenated, shuffled samples used " + "to train the skip-gram classifier", + ) + parser.add_argument( + "--output-validation-file", + type=str, + default="./fasttext_samples.valid", + help="The concatenated, shuffled samples used to " + "for computing validation metrics", + ) + parser.add_argument( + "--validation-split", + type=float, + default=0.9, + help="The training validation split", + ) + parser.add_argument( + "--output-model", + type=str, + default=None, + required=True, + help="The output trained skip-gram classifier written " "as a FastText model", + ) + parser.add_argument( + "--wordNgrams", + type=int, + default=2, + help="The size of the word n-gram used to train the classifier " + "(default is bigram)", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=0.1, + help="The learning rate used to train the classifier", + ) + parser.add_argument( + "--num-epochs", + type=int, + default=5, + help="Number of epochs used to train the classifier", + ) + parser.add_argument( + "--word-vector-dim", + type=int, + default=100, + help="Size of word vectors to be computed by the model", + ) + parser.add_argument( + "--output-predictions", + type=str, + default=None, + help="The output predictions on the validation data. " + "If a file is not specified, the predictions are not " + "written to file", + ) + return parser def console_script(): - main(attach_args().parse_args()) + main(attach_args().parse_args()) diff --git a/nemo_curator/tasks/__init__.py b/nemo_curator/tasks/__init__.py index 6091d2bc..5adf3137 100644 --- a/nemo_curator/tasks/__init__.py +++ b/nemo_curator/tasks/__init__.py @@ -13,6 +13,66 @@ # limitations under the License. 
from .downstream_task import DownstreamTask, import_task -from .metrics import Race, Squad, ArcEasy, ArcChallenge, OpenBookQA, BoolQ, Copa, RTE, MultiRC, WSC, CB, ANLI, Record, COQA, TriviaQA, Quac, WebQA, Drop, WiC, MMLU, BigBenchHard, BigBenchLight, Multilingual, PIQA, Winogrande, Lambada, NumDasc, StoryCloze +from .metrics import ( + ANLI, + CB, + COQA, + MMLU, + PIQA, + RTE, + WSC, + ArcChallenge, + ArcEasy, + BigBenchHard, + BigBenchLight, + BoolQ, + Copa, + Drop, + Lambada, + Multilingual, + MultiRC, + NumDasc, + OpenBookQA, + Quac, + Race, + Record, + Squad, + StoryCloze, + TriviaQA, + WebQA, + WiC, + Winogrande, +) -__all__ = ["DownstreamTask", "import_task", "Race", "Squad", "ArcEasy", "ArcChallenge", "OpenBookQA", "BoolQ", "Copa", "RTE", "MultiRC", "WSC", "CB", "ANLI", "Record", "COQA", "TriviaQA", "Quac", "WebQA", "Drop", "WiC", "MMLU", "BigBenchHard", "BigBenchLight", "Multilingual", "PIQA", "Winogrande", "Lambada", "NumDasc", "StoryCloze"] \ No newline at end of file +__all__ = [ + "DownstreamTask", + "import_task", + "Race", + "Squad", + "ArcEasy", + "ArcChallenge", + "OpenBookQA", + "BoolQ", + "Copa", + "RTE", + "MultiRC", + "WSC", + "CB", + "ANLI", + "Record", + "COQA", + "TriviaQA", + "Quac", + "WebQA", + "Drop", + "WiC", + "MMLU", + "BigBenchHard", + "BigBenchLight", + "Multilingual", + "PIQA", + "Winogrande", + "Lambada", + "NumDasc", + "StoryCloze", +] diff --git a/nemo_curator/tasks/downstream_task.py b/nemo_curator/tasks/downstream_task.py index 3e534398..0f75476d 100644 --- a/nemo_curator/tasks/downstream_task.py +++ b/nemo_curator/tasks/downstream_task.py @@ -20,41 +20,43 @@ class DownstreamTask(ABC): - def __init__(self): - super().__init__() - self._task_name = None - self._ngrams = {} + def __init__(self): + super().__init__() + self._task_name = None + self._ngrams = {} - @abstractmethod - def generate_ngrams(self): - pass + @abstractmethod + def generate_ngrams(self): + pass - @property - def ngrams(self): - return self._ngrams + @property + def ngrams(self): + return self._ngrams - def _update_ngrams(self, text, min_ngram_size=8, max_ngram_size=13): - words, positions = get_words(text) - if len(words) < min_ngram_size: - return + def _update_ngrams(self, text, min_ngram_size=8, max_ngram_size=13): + words, positions = get_words(text) + if len(words) < min_ngram_size: + return - if len(words) < max_ngram_size: - seq = " ".join(words) - if seq not in self._ngrams: - self._ngrams[seq] = 0 + if len(words) < max_ngram_size: + seq = " ".join(words) + if seq not in self._ngrams: + self._ngrams[seq] = 0 - for i in range(len(words) - max_ngram_size + 1): - seq = " ".join(words[i:i + max_ngram_size]) - if seq not in self._ngrams: - self._ngrams[seq] = 0 + for i in range(len(words) - max_ngram_size + 1): + seq = " ".join(words[i : i + max_ngram_size]) + if seq not in self._ngrams: + self._ngrams[seq] = 0 def import_task(task_path): - module_path, task_name = task_path.rsplit(".", 1) - task_module = importlib.import_module(module_path) - task_class = getattr(task_module, task_name) - if not issubclass(task_class, DownstreamTask): - raise ValueError(f"Input iterator {task_class.__name__} " - "must be derived from DownstreamTask" - "defined in nemo_curator.tasks.downstream_task") - return task_class + module_path, task_name = task_path.rsplit(".", 1) + task_module = importlib.import_module(module_path) + task_class = getattr(task_module, task_name) + if not issubclass(task_class, DownstreamTask): + raise ValueError( + f"Input iterator {task_class.__name__} " + "must be 
derived from DownstreamTask" + "defined in nemo_curator.tasks.downstream_task" + ) + return task_class diff --git a/nemo_curator/tasks/metrics.py b/nemo_curator/tasks/metrics.py index 74fec482..5127d68d 100644 --- a/nemo_curator/tasks/metrics.py +++ b/nemo_curator/tasks/metrics.py @@ -13,551 +13,573 @@ # limitations under the License. import json + from datasets import load_dataset + from nemo_curator.tasks.downstream_task import DownstreamTask from nemo_curator.utils.file_utils import get_all_files_paths_under class Race(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'race' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset(self._task_name, 'all', split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "race" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset(self._task_name, "all", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Squad(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'squad' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('squad_v2', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "squad" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("squad_v2", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class ArcEasy(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'arceasy' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('ai2_arc', 'ARC-Easy', split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "arceasy" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("ai2_arc", "ARC-Easy", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class ArcChallenge(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'arcchallenge' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('ai2_arc', 'ARC-Challenge', split='test') + def __init__(self, min_ngram_size=8, 
max_ngram_size=13): + super().__init__() + self._task_name = "arcchallenge" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("ai2_arc", "ARC-Challenge", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class OpenBookQA(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'openbookqa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('openbookqa', 'main', split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "openbookqa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("openbookqa", "main", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question_stem'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question_stem"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class BoolQ(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'boolq' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'boolq', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "boolq" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "boolq", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Copa(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'copa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'copa', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "copa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "copa", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['premise'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["premise"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class RTE(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'rte' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = 
load_dataset('glue', 'rte', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "rte" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("glue", "rte", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line["sentence1"] + '\n' + line["sentence2"] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["sentence1"] + "\n" + line["sentence2"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class MultiRC(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'multirc' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'multirc', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "multirc" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "multirc", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class WSC(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'wsc' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'multirc', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "wsc" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "multirc", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class CB(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'cb' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'cb', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "cb" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "cb", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['premise'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["premise"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class ANLI(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 
'anli' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('anli') - self._keys = ['test_r1', 'test_r2', 'test_r3'] - - def generate_ngrams(self): - for key in self._keys: - data = self._dataset[key] - for line in data: - try: - text = line['premise'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - text = line['hypothesis'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - except Exception as e: - print('Error:', e) - - return self.ngrams + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "anli" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("anli") + self._keys = ["test_r1", "test_r2", "test_r3"] + + def generate_ngrams(self): + for key in self._keys: + data = self._dataset[key] + for line in data: + try: + text = line["premise"] + self._update_ngrams( + text, self._min_ngram_size, self._max_ngram_size + ) + text = line["hypothesis"] + self._update_ngrams( + text, self._min_ngram_size, self._max_ngram_size + ) + except Exception as e: + print("Error:", e) + + return self.ngrams class Record(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'record' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('super_glue', 'record', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "record" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("super_glue", "record", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['query'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["query"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class COQA(DownstreamTask): - def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'coqa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - if file_path is None: - raise Exception("Must provide a path to the coqa.json file") - self._dataset = json.load(open(file_path))['data'] - - def generate_ngrams(self): - for line in self._dataset: - all_questions = line['questions'] - for question in all_questions: - self._update_ngrams( - question['input_text'], - self._min_ngram_size, - self._max_ngram_size, - ) - story = line['story'] - self._update_ngrams(story, self._min_ngram_size, self._max_ngram_size) - - return self.ngrams + def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "coqa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + if file_path is None: + raise Exception("Must provide a path to the coqa.json file") + self._dataset = json.load(open(file_path))["data"] + + def generate_ngrams(self): + for line in self._dataset: + all_questions = line["questions"] + for question in all_questions: + self._update_ngrams( + question["input_text"], + self._min_ngram_size, + self._max_ngram_size, + ) + story = line["story"] + self._update_ngrams(story, self._min_ngram_size, self._max_ngram_size) + + return self.ngrams class 
TriviaQA(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'trivia_qa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('trivia_qa', 'unfiltered', split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "trivia_qa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("trivia_qa", "unfiltered", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Quac(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'quac' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('quac', split='validation') - - def generate_ngrams(self): - for line in self._dataset: - all_questions = line['questions'] - for question in all_questions: - self._update_ngrams( - question, - self._min_ngram_size, - self._max_ngram_size, - ) + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "quac" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("quac", split="validation") - return self.ngrams + def generate_ngrams(self): + for line in self._dataset: + all_questions = line["questions"] + for question in all_questions: + self._update_ngrams( + question, + self._min_ngram_size, + self._max_ngram_size, + ) + + return self.ngrams class WebQA(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'webqa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('web_questions', split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "webqa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("web_questions", split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Drop(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'drop' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset('drop', split='validation') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "drop" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset("drop", split="validation") - def generate_ngrams(self): - for line in self._dataset: - text = line['question'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["question"] + 
self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class WiC(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'wic' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset( - path='super_glue', - name='wic', - split='validation', - ) + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "wic" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset( + path="super_glue", + name="wic", + split="validation", + ) - def generate_ngrams(self): - for line in self._dataset: - text = line["sentence1"] + '\n' + line["sentence2"] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["sentence1"] + "\n" + line["sentence2"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class MMLU(DownstreamTask): - def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'mmlu' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._path = path - if self._path is None: - raise Exception("Must provide path that contain " - "MMLU task data in JSONL format") + def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "mmlu" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._path = path + if self._path is None: + raise Exception( + "Must provide path that contain " "MMLU task data in JSONL format" + ) - def generate_ngrams(self): - for ifile in get_all_files_paths_under(self._path): - for iline in open(ifile, 'rb'): - document = json.loads(iline) - text = document['text'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for ifile in get_all_files_paths_under(self._path): + for iline in open(ifile, "rb"): + document = json.loads(iline) + text = document["text"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class BigBenchHard(DownstreamTask): - def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'bigbench_hard' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._path = path - if self._path is None: - raise Exception("Must provide path that contain " - "BigBenchHard task data in JSONL format") + def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "bigbench_hard" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._path = path + if self._path is None: + raise Exception( + "Must provide path that contain " + "BigBenchHard task data in JSONL format" + ) - def generate_ngrams(self): - for ifile in get_all_files_paths_under(self._path): - for iline in open(ifile, 'rb'): - document = json.loads(iline) - text = document['text'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for ifile in get_all_files_paths_under(self._path): + for iline in open(ifile, "rb"): + document = json.loads(iline) + text = document["text"] + self._update_ngrams(text, self._min_ngram_size, 
self._max_ngram_size) - return self.ngrams + return self.ngrams class BigBenchLight(DownstreamTask): - def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'bigbench_light' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._path = path - if self._path is None: - raise Exception("Must provide path that contain " - "BigBenchLight task data in JSONL format") + def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "bigbench_light" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._path = path + if self._path is None: + raise Exception( + "Must provide path that contain " + "BigBenchLight task data in JSONL format" + ) - def generate_ngrams(self): - for ifile in get_all_files_paths_under(self._path): - for iline in open(ifile, 'rb'): - document = json.loads(iline) - text = document['text'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for ifile in get_all_files_paths_under(self._path): + for iline in open(ifile, "rb"): + document = json.loads(iline) + text = document["text"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Multilingual(DownstreamTask): - def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'multilingual' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._path = path - if self._path is None: - raise Exception("Must provide path to " - "multilingual task data in JSONL format") + def __init__(self, path=None, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "multilingual" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._path = path + if self._path is None: + raise Exception( + "Must provide path to " "multilingual task data in JSONL format" + ) - def generate_ngrams(self): - for ifile in get_all_files_paths_under(self._path): - for iline in open(ifile, 'rb'): - document = json.loads(iline) - text = document['text'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for ifile in get_all_files_paths_under(self._path): + for iline in open(ifile, "rb"): + document = json.loads(iline) + text = document["text"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class PIQA(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'piqa' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset(self._task_name, split='test') + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "piqa" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset(self._task_name, split="test") - def generate_ngrams(self): - for line in self._dataset: - text = line['goal'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["goal"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Winogrande(DownstreamTask): - def __init__(self, min_ngram_size=8, max_ngram_size=13): 
- super().__init__() - self._task_name = 'winogrande' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._dataset = load_dataset( - path='winogrande', - name='winogrande_xl', - split='validation', - ) + def __init__(self, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "winogrande" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._dataset = load_dataset( + path="winogrande", + name="winogrande_xl", + split="validation", + ) - def generate_ngrams(self): - for line in self._dataset: - text = line['sentence'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) + def generate_ngrams(self): + for line in self._dataset: + text = line["sentence"] + self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - return self.ngrams + return self.ngrams class Lambada(DownstreamTask): - def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'lambada' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._file_path = file_path + def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "lambada" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._file_path = file_path - def generate_ngrams(self): - with open(self._file_path, 'r') as f: - for line in f: - try: - myjson = json.loads(line) - text = myjson['text'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - except Exception as e: - print(f"Error {e}") + def generate_ngrams(self): + with open(self._file_path, "r") as f: + for line in f: + try: + myjson = json.loads(line) + text = myjson["text"] + self._update_ngrams( + text, self._min_ngram_size, self._max_ngram_size + ) + except Exception as e: + print(f"Error {e}") - return self.ngrams + return self.ngrams class NumDasc(DownstreamTask): - def __init__(self, n, file_path, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._n = n - self._task_name = '{n}dasc' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._file_path = file_path - - def generate_ngrams(self): - with open(self._file_path, 'r') as f: - for line in f: - try: - myjson = json.loads(line) - text = myjson['context'] + myjson['completion'] - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - except Exception as e: - print(f"Error {e}") - - return self.ngrams + def __init__(self, n, file_path, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._n = n + self._task_name = "{n}dasc" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._file_path = file_path + + def generate_ngrams(self): + with open(self._file_path, "r") as f: + for line in f: + try: + myjson = json.loads(line) + text = myjson["context"] + myjson["completion"] + self._update_ngrams( + text, self._min_ngram_size, self._max_ngram_size + ) + except Exception as e: + print(f"Error {e}") + + return self.ngrams class StoryCloze(DownstreamTask): - def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): - super().__init__() - self._task_name = 'story_cloze' - self._min_ngram_size = min_ngram_size - self._max_ngram_size = max_ngram_size - self._file_path = file_path - - def generate_ngrams(self): - with open(self._file_path, 'r') as f: - for line in f: - try: - myjson = json.loads(line) - text = " ".join([ - 
myjson["InputSentence1"], myjson["InputSentence2"], - myjson["InputSentence3"], myjson["InputSentence4"] - ]) - self._update_ngrams(text, self._min_ngram_size, self._max_ngram_size) - except Exception as e: - print(f"Error {e}") - - return self.ngrams \ No newline at end of file + def __init__(self, file_path, min_ngram_size=8, max_ngram_size=13): + super().__init__() + self._task_name = "story_cloze" + self._min_ngram_size = min_ngram_size + self._max_ngram_size = max_ngram_size + self._file_path = file_path + + def generate_ngrams(self): + with open(self._file_path, "r") as f: + for line in f: + try: + myjson = json.loads(line) + text = " ".join( + [ + myjson["InputSentence1"], + myjson["InputSentence2"], + myjson["InputSentence3"], + myjson["InputSentence4"], + ] + ) + self._update_ngrams( + text, self._min_ngram_size, self._max_ngram_size + ) + except Exception as e: + print(f"Error {e}") + + return self.ngrams diff --git a/nemo_curator/utils/__init__.py b/nemo_curator/utils/__init__.py index fe99e99a..d9155f92 100644 --- a/nemo_curator/utils/__init__.py +++ b/nemo_curator/utils/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/nemo_curator/utils/code_meta.csv b/nemo_curator/utils/code_meta.csv index 7dc2134e..1b82d9b0 100644 --- a/nemo_curator/utils/code_meta.csv +++ b/nemo_curator/utils/code_meta.csv @@ -303,4 +303,4 @@ ,,,,,,,,,300,,,,,,,,,,,, ,json,json,,,,,,,Raymond,,, ,"Manually added JSON and YAML",,1,0.25,,,1,0.5, ,yaml,yaml,,,,,,,Raymond,,, ,"Manually added JSON and YAML",,1,0.25,1000,,1,0.5, -,yml,yaml,,,,,,,Raymond,,, ,"Manually added JSON and YAML",,1,0.25,1000,,1,0.5, \ No newline at end of file +,yml,yaml,,,,,,,Raymond,,, ,"Manually added JSON and YAML",,1,0.25,1000,,1,0.5, diff --git a/nemo_curator/utils/config_utils.py b/nemo_curator/utils/config_utils.py index 12ff4a19..add75d60 100644 --- a/nemo_curator/utils/config_utils.py +++ b/nemo_curator/utils/config_utils.py @@ -12,70 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from pydoc import locate + import yaml + import nemo_curator -from nemo_curator.filters import import_filter from nemo_curator.download.doc_builder import ( import_downloader, - import_iterator, import_extractor, + import_iterator, ) +from nemo_curator.filters import import_filter from nemo_curator.utils.file_utils import expand_outdir_and_mkdir -from pydoc import locate + def build_filter(filter_config): - # Import the filter - filter_class = import_filter(filter_config['name']) + # Import the filter + filter_class = import_filter(filter_config["name"]) - # Check if constructor has been provided - if ('params' not in filter_config) or (filter_config['params'] is None): - filter_config['params'] = {} - - doc_filter = filter_class(**filter_config["params"]) + # Check if constructor has been provided + if ("params" not in filter_config) or (filter_config["params"] is None): + filter_config["params"] = {} - if filter_config.get("filter_only", False): - filter_stage = nemo_curator.Filter(doc_filter.keep_document, filter_field=doc_filter.name) - else: - score_field = doc_filter._name if filter_config.get("log_score", False) else None - filter_stage = nemo_curator.ScoreFilter(doc_filter, filter_config.get("input_field"), score_field=score_field) + doc_filter = filter_class(**filter_config["params"]) + + if filter_config.get("filter_only", False): + filter_stage = nemo_curator.Filter( + doc_filter.keep_document, filter_field=doc_filter.name + ) + else: + score_field = ( + doc_filter._name if filter_config.get("log_score", False) else None + ) + filter_stage = nemo_curator.ScoreFilter( + doc_filter, filter_config.get("input_field"), score_field=score_field + ) + + return filter_stage - return filter_stage def build_filter_pipeline(filter_config_file): - # Get the filter config file - with open(filter_config_file, 'r') as config_file: - filter_params = yaml.load(config_file, Loader=yaml.FullLoader) + # Get the filter config file + with open(filter_config_file, "r") as config_file: + filter_params = yaml.load(config_file, Loader=yaml.FullLoader) - filters = [] - text_field = filter_params.get("input_field") - for nc_filter_config in filter_params.get("filters"): - if "input_field" not in nc_filter_config or nc_filter_config["input_field"] is None: - nc_filter_config["input_field"] = text_field - new_filter = build_filter(nc_filter_config) - filters.append(new_filter) + filters = [] + text_field = filter_params.get("input_field") + for nc_filter_config in filter_params.get("filters"): + if ( + "input_field" not in nc_filter_config + or nc_filter_config["input_field"] is None + ): + nc_filter_config["input_field"] = text_field + new_filter = build_filter(nc_filter_config) + filters.append(new_filter) + + return nemo_curator.Sequential(filters) - return nemo_curator.Sequential(filters) def build_downloader(downloader_config_file, default_download_dir=None): - # Get the downloader config file - with open(downloader_config_file, 'r') as config_file: - downloader_params = yaml.load(config_file, Loader=yaml.FullLoader) - - download_class = import_downloader(downloader_params['download_module']) - no_download_dir = ("download_dir" not in downloader_params['download_params']) or (downloader_params['download_params'] is None) - if no_download_dir and default_download_dir: - downloader_params['download_params']['download_dir'] = default_download_dir - expand_outdir_and_mkdir(downloader_params['download_params']['download_dir']) - downloader = download_class(**downloader_params['download_params']) - - 
iterator_class = import_iterator(downloader_params['iterator_module']) - iterator = iterator_class(**downloader_params['iterator_params']) - - extractor_class = import_extractor(downloader_params['extract_module']) - extractor = extractor_class(**downloader_params['extract_params']) - - dataset_format = {} - for field, field_type in downloader_params["format"].items(): - dataset_format[field] = locate(field_type) - - return downloader, iterator, extractor, dataset_format \ No newline at end of file + # Get the downloader config file + with open(downloader_config_file, "r") as config_file: + downloader_params = yaml.load(config_file, Loader=yaml.FullLoader) + + download_class = import_downloader(downloader_params["download_module"]) + no_download_dir = ("download_dir" not in downloader_params["download_params"]) or ( + downloader_params["download_params"] is None + ) + if no_download_dir and default_download_dir: + downloader_params["download_params"]["download_dir"] = default_download_dir + expand_outdir_and_mkdir(downloader_params["download_params"]["download_dir"]) + downloader = download_class(**downloader_params["download_params"]) + + iterator_class = import_iterator(downloader_params["iterator_module"]) + iterator = iterator_class(**downloader_params["iterator_params"]) + + extractor_class = import_extractor(downloader_params["extract_module"]) + extractor = extractor_class(**downloader_params["extract_params"]) + + dataset_format = {} + for field, field_type in downloader_params["format"].items(): + dataset_format[field] = locate(field_type) + + return downloader, iterator, extractor, dataset_format diff --git a/nemo_curator/utils/constants.py b/nemo_curator/utils/constants.py index c374aa1f..d14bc808 100644 --- a/nemo_curator/utils/constants.py +++ b/nemo_curator/utils/constants.py @@ -13,12 +13,13 @@ # limitations under the License. import re -import regex +import regex -end_marks = (".", "?", "!", "\"", "\'") +end_marks = (".", "?", "!", '"', "'") ellipsis_marks = set( - ["...", "[...]", "…", "(...)", "[…]", "-»", "read more..", "read more"]) + ["...", "[...]", "…", "(...)", "[…]", "-»", "read more..", "read more"] +) policy_substrings = [ "terms of use", "privacy policy", @@ -54,27 +55,30 @@ "settings, you agree to this use. 
AcceptRead More".lower(), ] white_space_list = ["\t", "\n", "\r", "\b", " "] -common_english_words = set( - ['the', 'be', 'to', 'of', 'and', 'that', 'have', 'with']) -bullet_list = set([ - '•', - '‣', - '⁃', - '⁌', - '⁍', - '∙', - '○', - '●', - '◘', - '◦', - '⦾', - '⦿', -]) +common_english_words = set(["the", "be", "to", "of", "and", "that", "have", "with"]) +bullet_list = set( + [ + "•", + "‣", + "⁃", + "⁌", + "⁍", + "∙", + "○", + "●", + "◘", + "◦", + "⦾", + "⦿", + ] +) regex_alpha = regex.compile("[[:alpha:]]") regex_digit = regex.compile("[[:digit:]]") -regex_alphanum = re.compile('[a-zA-Z0-9\n?!,.]') -regex_url = re.compile('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|' - '(?:%[0-9a-fA-F][0-9a-fA-F]))+') +regex_alphanum = re.compile("[a-zA-Z0-9\n?!,.]") +regex_url = re.compile( + "http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|" + "(?:%[0-9a-fA-F][0-9a-fA-F]))+" +) regex_paren = re.compile(r"{|}|⟨|⟩|\[|\]|\(|\)") -regex_hash = re.compile('#+') +regex_hash = re.compile("#+") diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index 6a73aa2c..71fa1cdc 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -16,19 +16,20 @@ os.environ["RAPIDS_NO_INITIALIZE"] = "1" import warnings +from pathlib import Path +from typing import Union + import cudf import dask.dataframe as dd import dask_cudf import pandas as pd from dask.distributed import Client, LocalCluster, get_worker from dask_cuda import LocalCUDACluster -from typing import Union -from pathlib import Path class DotDict: def __init__(self, d): - self.__dict__['_data'] = d + self.__dict__["_data"] = d def __getattr__(self, name): if name in self._data: @@ -40,6 +41,7 @@ def __getattr__(self, name): class NoWorkerError(Exception): pass + def start_dask_gpu_local_cluster(args) -> Client: """ This function sets up a Dask cluster across all the @@ -384,7 +386,9 @@ def single_partition_write_with_filename(df, output_file_dir, output_type="jsonl if output_type == "jsonl": output_file_path = output_file_path + ".jsonl" if isinstance(df, pd.DataFrame): - df.to_json(output_file_path, orient="records", lines=True, force_ascii=False) + df.to_json( + output_file_path, orient="records", lines=True, force_ascii=False + ) else: # See open issue here: https://github.com/rapidsai/cudf/issues/15211 # df.to_json( @@ -439,9 +443,13 @@ def write_to_disk(df, output_file_dir, write_to_filename=False, output_type="jso if isinstance(df, dask_cudf.DataFrame): # See open issue here: https://github.com/rapidsai/cudf/issues/15211 # df.to_json(output_file_dir, orient="records", lines=True, engine="cudf", force_ascii=False) - df.to_json(output_file_dir, orient="records", lines=True, force_ascii=False) + df.to_json( + output_file_dir, orient="records", lines=True, force_ascii=False + ) else: - df.to_json(output_file_dir, orient="records", lines=True, force_ascii=False) + df.to_json( + output_file_dir, orient="records", lines=True, force_ascii=False + ) elif output_type == "parquet": df.to_parquet(output_file_dir, write_index=False) else: diff --git a/nemo_curator/utils/download_utils.py b/nemo_curator/utils/download_utils.py index 279beede..7c33c1ec 100644 --- a/nemo_curator/utils/download_utils.py +++ b/nemo_curator/utils/download_utils.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
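
For reference, the build_filter and build_filter_pipeline helpers reformatted at the top of this section imply a small YAML schema: a top-level input_field, plus a filters list whose entries carry a name that import_filter resolves, an optional params mapping passed straight to the filter constructor, and optional filter_only and log_score flags. A minimal sketch of such a config follows; the dotted filter paths and the min_words value are illustrative assumptions rather than values taken from this diff, and the invocation is left commented because the module that defines build_filter_pipeline is not shown here.

import yaml

# Illustrative config for build_filter_pipeline (defined in the hunk above).
# The dotted filter paths and "min_words" are assumptions for this sketch.
config_text = """
input_field: text
filters:
  - name: nemo_curator.filters.heuristic_filter.WordCountFilter
    params:
      min_words: 50
    log_score: true
  - name: nemo_curator.filters.heuristic_filter.UrlsFilter
"""

with open("filter_pipeline.yaml", "w") as config_file:
    config_file.write(config_text)

# Same loader the builder uses; shows the structure it will walk.
print(yaml.load(config_text, Loader=yaml.FullLoader))

# pipeline = build_filter_pipeline("filter_pipeline.yaml")
# filtered_dataset = pipeline(dataset)  # dataset: a nemo_curator DocumentDataset
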
-import os -import requests import json +import os +import subprocess import zlib -from urllib.parse import urljoin -from datetime import datetime, timedelta from collections import OrderedDict +from datetime import datetime, timedelta from typing import List, Optional +from urllib.parse import urljoin + +import requests from bs4 import BeautifulSoup -import subprocess def get_main_warc_paths( @@ -30,77 +31,88 @@ def get_main_warc_paths( end_snapshot, prefix="https://data.commoncrawl.org", ): - beg_year, beg_week = list(map(int, start_snapshot.split('-'))) - end_year, end_week = list(map(int, end_snapshot.split('-'))) - start_date = datetime.fromisocalendar(beg_year, beg_week, 1) - end_date = datetime.fromisocalendar(end_year, end_week, 1) + beg_year, beg_week = list(map(int, start_snapshot.split("-"))) + end_year, end_week = list(map(int, end_snapshot.split("-"))) + start_date = datetime.fromisocalendar(beg_year, beg_week, 1) + end_date = datetime.fromisocalendar(end_year, end_week, 1) - if beg_year < 2013 or end_year < 2013: - print("Warning: Only snapshots after 2013 are supported by this script") + if beg_year < 2013 or end_year < 2013: + print("Warning: Only snapshots after 2013 are supported by this script") - total_prefix = urljoin(prefix, 'crawl-data/CC-MAIN') + total_prefix = urljoin(prefix, "crawl-data/CC-MAIN") - warc_paths = [] - for snapshot in snapshot_index: - date = list(map(int, snapshot['id'].split('-')[2:])) + warc_paths = [] + for snapshot in snapshot_index: + date = list(map(int, snapshot["id"].split("-")[2:])) - if len(date) == 2: - year, week = date - else: - continue + if len(date) == 2: + year, week = date + else: + continue - if year >= 2013: - curr_date = datetime.fromisocalendar(year, week, 1) - if curr_date >= start_date and curr_date <= end_date: - warc_path = f'{total_prefix}-{year}-{week:02d}/warc.paths.gz' - warc_paths.append(warc_path) + if year >= 2013: + curr_date = datetime.fromisocalendar(year, week, 1) + if curr_date >= start_date and curr_date <= end_date: + warc_path = f"{total_prefix}-{year}-{week:02d}/warc.paths.gz" + warc_paths.append(warc_path) + + return warc_paths - return warc_paths def get_news_warc_paths( start_date, end_date, prefix="https://data.commoncrawl.org", ): - beg = datetime.strptime(start_date, "%Y-%m") - end = datetime.strptime(end_date, "%Y-%m") - - # Get current year and month - today = datetime.now() - - if beg.year < 2016 or end.year > today.year: - print("Warning: WARC paths exist only from 2016-8 to " - f"{today.year}-{today.month}") - total_prefix = urljoin(prefix, 'crawl-data/CC-NEWS') - - # Generate all valid YYYY-MM strings in range - dates = OrderedDict() - for day in range((end - beg).days + 1): - new_date = beg + timedelta(day) - dates[(new_date.year, new_date.month)] = None - - dates = list(dates.keys()) + beg = datetime.strptime(start_date, "%Y-%m") + end = datetime.strptime(end_date, "%Y-%m") + + # Get current year and month + today = datetime.now() + + if beg.year < 2016 or end.year > today.year: + print( + "Warning: WARC paths exist only from 2016-8 to " + f"{today.year}-{today.month}" + ) + total_prefix = urljoin(prefix, "crawl-data/CC-NEWS") + + # Generate all valid YYYY-MM strings in range + dates = OrderedDict() + for day in range((end - beg).days + 1): + new_date = beg + timedelta(day) + dates[(new_date.year, new_date.month)] = None + + dates = list(dates.keys()) + + warc_paths = [] + for year, month in dates: + warc_path = f"{total_prefix}/{year}/{month:02d}/warc.paths.gz" + warc_paths.append(warc_path) - 
warc_paths = [] - for year, month in dates: - warc_path = f'{total_prefix}/{year}/{month:02d}/warc.paths.gz' - warc_paths.append(warc_path) + return warc_paths - return warc_paths def get_common_crawl_snapshot_index(index_prefix): - index_url = urljoin(index_prefix, "collinfo.json") - index_response = requests.get(index_url) + index_url = urljoin(index_prefix, "collinfo.json") + index_response = requests.get(index_url) + + return json.loads(index_response.content) - return json.loads(index_response.content) -def get_common_crawl_urls(starting_snapshot: str, ending_snapshot: str, data_domain_prefix="https://data.commoncrawl.org", index_prefix="https://index.commoncrawl.org", news=False) -> List[str]: +def get_common_crawl_urls( + starting_snapshot: str, + ending_snapshot: str, + data_domain_prefix="https://data.commoncrawl.org", + index_prefix="https://index.commoncrawl.org", + news=False, +) -> List[str]: """ Retrieves the URLs for all the compressed WARC files between given Common Crawl snapshots Args: - starting_snapshot: The first common crawl snapshot to include. Snapshots must be - specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, + starting_snapshot: The first common crawl snapshot to include. Snapshots must be + specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, (specified with news=True flag) this changes to Year-Month (YYYY-MM). ending_snapshot: The last common crawl snapshot to include. Must be chronologically after the starting snapshot. @@ -110,18 +122,22 @@ def get_common_crawl_urls(starting_snapshot: str, ending_snapshot: str, data_dom Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month). """ if news: - warc_paths = get_news_warc_paths(starting_snapshot, ending_snapshot, prefix=data_domain_prefix) + warc_paths = get_news_warc_paths( + starting_snapshot, ending_snapshot, prefix=data_domain_prefix + ) else: index = get_common_crawl_snapshot_index(index_prefix) - warc_paths = get_main_warc_paths(index, starting_snapshot, ending_snapshot, prefix=data_domain_prefix) - + warc_paths = get_main_warc_paths( + index, starting_snapshot, ending_snapshot, prefix=data_domain_prefix + ) + common_crawl_urls = [] for path in warc_paths: try: response = requests.get(path.rstrip(), stream=True) data = zlib.decompress(response.content, zlib.MAX_WBITS | 32) - for warc in data.decode('utf-8').split('\n'): - if warc != '': + for warc in data.decode("utf-8").split("\n"): + if warc != "": warc_url = urljoin(data_domain_prefix, warc) common_crawl_urls.append(warc_url) except Exception as e: @@ -131,7 +147,12 @@ def get_common_crawl_urls(starting_snapshot: str, ending_snapshot: str, data_dom return common_crawl_urls -def get_wikipedia_urls(language="en", wikidumps_index_prefix="https://dumps.wikimedia.org", dump_date: Optional[str]=None) -> List[str]: + +def get_wikipedia_urls( + language="en", + wikidumps_index_prefix="https://dumps.wikimedia.org", + dump_date: Optional[str] = None, +) -> List[str]: """ Retrieves all urls pointing to the latest Wikipedia dumps @@ -143,44 +164,45 @@ def get_wikipedia_urls(language="en", wikidumps_index_prefix="https://dumps.wiki """ wiki_index_url = urljoin(wikidumps_index_prefix, f"{language}wiki") if not dump_date: - # First get the index - raw_wiki_index = requests.get(wiki_index_url) - wiki_index = raw_wiki_index.content.decode("utf-8") - wiki_index_parsed = BeautifulSoup(wiki_index, "lxml") - - # Get all dumps available in the index - dumps = 
wiki_index_parsed.find_all("a") - dump_date = dumps[-2].text + # First get the index + raw_wiki_index = requests.get(wiki_index_url) + wiki_index = raw_wiki_index.content.decode("utf-8") + wiki_index_parsed = BeautifulSoup(wiki_index, "lxml") + + # Get all dumps available in the index + dumps = wiki_index_parsed.find_all("a") + dump_date = dumps[-2].text else: - # A trailing / is needed for the url - dump_date = dump_date + "/" + # A trailing / is needed for the url + dump_date = dump_date + "/" # Get the json dump data - wiki_latest_dump = urljoin(wiki_index_url + '/', dump_date) + wiki_latest_dump = urljoin(wiki_index_url + "/", dump_date) wiki_latest_dump_status = urljoin(wiki_latest_dump, "dumpstatus.json") raw_dump_data = requests.get(wiki_latest_dump_status) try: - dump_data = json.loads(raw_dump_data.content) + dump_data = json.loads(raw_dump_data.content) except json.decoder.JSONDecodeError: - raise ValueError(f"No wikipedia dump found for {dump_date[:-1]}") + raise ValueError(f"No wikipedia dump found for {dump_date[:-1]}") # Get all multistream files within the dump data wikipedia_urls = [] - for ifile in dump_data['jobs']['articlesmultistreamdump']['files']: - if 'xml' in ifile: - url = urljoin(wiki_latest_dump, ifile) - wikipedia_urls.append(url) - + for ifile in dump_data["jobs"]["articlesmultistreamdump"]["files"]: + if "xml" in ifile: + url = urljoin(wiki_latest_dump, ifile) + wikipedia_urls.append(url) + return wikipedia_urls + def get_arxiv_urls(): - command = "s5cmd --request-payer=requester ls s3://arxiv/src/ | grep '.tar'" - result = subprocess.run(command, capture_output=True, text=True, shell=True) - - if result.returncode != 0: - raise RuntimeError(f"Unable to get arxiv urls: {result.stderr}") - - urls = result.stdout.split()[3::4] - urls.sort() - - return urls \ No newline at end of file + command = "s5cmd --request-payer=requester ls s3://arxiv/src/ | grep '.tar'" + result = subprocess.run(command, capture_output=True, text=True, shell=True) + + if result.returncode != 0: + raise RuntimeError(f"Unable to get arxiv urls: {result.stderr}") + + urls = result.stdout.split()[3::4] + urls.sort() + + return urls diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index ecc15780..af3c2513 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
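
The reflowed download_utils.py keeps its public helpers intact, and their signatures are fully visible above: get_common_crawl_urls(starting_snapshot, ending_snapshot, ..., news=False), get_wikipedia_urls(language, ..., dump_date=None), and get_arxiv_urls(), which shells out to s5cmd with requester-pays access. A brief usage sketch; the snapshot ranges are only examples and every call below reaches the network.

from nemo_curator.utils.download_utils import (
    get_common_crawl_urls,
    get_wikipedia_urls,
)

# Main Common Crawl snapshots are addressed by ISO year and week, e.g. "2021-04".
warc_urls = get_common_crawl_urls("2021-04", "2021-10")
print(len(warc_urls), warc_urls[0])

# CC-NEWS is addressed by year-month instead and selected with news=True.
news_urls = get_common_crawl_urls("2021-04", "2021-10", news=True)

# URLs of the latest English Wikipedia multistream dump files.
wiki_urls = get_wikipedia_urls(language="en")
print(wiki_urls[:3])
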
+import json import os import pathlib -from functools import reduce -import pandas as pd -import dask.dataframe as dd -from dask import delayed +from functools import partial, reduce + import dask.bag as db +import dask.dataframe as dd import numpy as np -import json -from functools import partial +import pandas as pd +from dask import delayed from nemo_curator.utils.distributed_utils import single_partition_write_with_filename @@ -58,6 +58,7 @@ def get_all_files_paths_under(root, recurse_subdirectories=True, followlinks=Fal file_ls.sort() return file_ls + # Using this for restarting jobs # can lead to problems when there is an error while # writing a file we can use the offset counter approach @@ -89,7 +90,9 @@ def get_remaining_files(input_file_path, output_file_path, input_file_type): return input_files -def get_batched_files(input_file_path, output_file_path, input_file_type, batch_size=64): +def get_batched_files( + input_file_path, output_file_path, input_file_type, batch_size=64 +): """ This function returns a batch of files that still remain to be processed. @@ -101,11 +104,16 @@ def get_batched_files(input_file_path, output_file_path, input_file_type, batch_ Returns: A batch of files that are not in the output directory. """ - remaining_files = get_remaining_files(input_file_path, output_file_path, input_file_type) + remaining_files = get_remaining_files( + input_file_path, output_file_path, input_file_type + ) for i in range(0, len(remaining_files), batch_size): - yield remaining_files[i:i + batch_size] + yield remaining_files[i : i + batch_size] -def write_dataframe_by_meta(df: pd.DataFrame, output_dir, metadata_field, remove_metadata, output_type): + +def write_dataframe_by_meta( + df: pd.DataFrame, output_dir, metadata_field, remove_metadata, output_type +): counts = df[metadata_field].value_counts().to_dict() for meta_value in counts: @@ -113,17 +121,27 @@ def write_dataframe_by_meta(df: pd.DataFrame, output_dir, metadata_field, remove meta_slice = df[df[metadata_field] == meta_value] if remove_metadata: meta_slice = meta_slice.drop(columns=[metadata_field]) - single_partition_write_with_filename(meta_slice, meta_output_dir, output_type=output_type) - + single_partition_write_with_filename( + meta_slice, meta_output_dir, output_type=output_type + ) + return counts + def merge_counts(first: dict, second: dict): for ngram, count in second.items(): first[ngram] = first.get(ngram, 0) + count - + return first -def separate_by_metadata(df: dd.DataFrame, output_dir, metadata_field, remove_metadata=False, output_type="jsonl") -> dict: + +def separate_by_metadata( + df: dd.DataFrame, + output_dir, + metadata_field, + remove_metadata=False, + output_type="jsonl", +) -> dict: """ Saves the dataframe to subfolders named after a metadata @@ -132,16 +150,22 @@ def separate_by_metadata(df: dd.DataFrame, output_dir, metadata_field, remove_me output_dir: The base directory for which all metadata based subdirs will be created under metadata_field: The metadata field to split on remove_metadata: Whether to remove the metadata from the dataframe when saving it - + Returns: A delayed dictionary mapping each metadata to the count of entries with that metadata value. 
""" delayed_data = df.to_delayed() - delayed_counts = [delayed(write_dataframe_by_meta)(partition, output_dir, metadata_field, remove_metadata, output_type) for partition in delayed_data] + delayed_counts = [ + delayed(write_dataframe_by_meta)( + partition, output_dir, metadata_field, remove_metadata, output_type + ) + for partition in delayed_data + ] merged_counts = delayed(reduce)(merge_counts, delayed_counts) return merged_counts + def parse_str_of_num_bytes(s, return_str=False): try: power = "kmg".find(s[-1].lower()) + 1 @@ -153,29 +177,33 @@ def parse_str_of_num_bytes(s, return_str=False): else: return int(size) + def _save_jsonl(documents, output_path, start_index=0, max_index=10000, prefix=None): - """ Worker function to write out the data to jsonl files """ + """Worker function to write out the data to jsonl files""" + + def _output_json(document): + myjson = json.dumps(document, ensure_ascii=False) + return myjson.encode("utf-8") - def _output_json(document): - myjson = json.dumps(document, ensure_ascii=False) - return myjson.encode('utf-8') + def _name(start_index, npad, prefix, i): + tag = str(start_index + i).rjust(npad, "0") + return f"{prefix}{tag}" - def _name(start_index, npad, prefix, i): - tag = str(start_index + i).rjust(npad, '0') - return f"{prefix}{tag}" + # Create the naming function + npad = int(np.log10(max_index) + 1) + name = partial(_name, start_index, npad, prefix) - # Create the naming function - npad = int(np.log10(max_index) + 1) - name = partial(_name, start_index, npad, prefix) + output_glob_string = os.path.join(output_path, "*.jsonl") - output_glob_string = os.path.join(output_path, "*.jsonl") + documents.map(_output_json).to_textfiles( + output_glob_string, + name_function=name, + ) - documents.map(_output_json).to_textfiles( - output_glob_string, - name_function=name, - ) -def reshard_jsonl(input_dir, output_dir, output_file_size="100M", start_index=0, file_prefix=""): +def reshard_jsonl( + input_dir, output_dir, output_file_size="100M", start_index=0, file_prefix="" +): """ Reshards a directory of jsonl files to have a new (approximate) file size for each shard @@ -200,4 +228,4 @@ def reshard_jsonl(input_dir, output_dir, output_file_size="100M", start_index=0, output_dir = expand_outdir_and_mkdir(output_dir) # Save to balanced files - _save_jsonl(b, output_dir, start_index=start_index, prefix=file_prefix) \ No newline at end of file + _save_jsonl(b, output_dir, start_index=start_index, prefix=file_prefix) diff --git a/nemo_curator/utils/fuzzy_dedup_utils/__init__.py b/nemo_curator/utils/fuzzy_dedup_utils/__init__.py index fe99e99a..d9155f92 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/__init__.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/nemo_curator/utils/fuzzy_dedup_utils/id_mapping.py b/nemo_curator/utils/fuzzy_dedup_utils/id_mapping.py index a566139e..5f87aa7d 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/id_mapping.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/id_mapping.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ def convert_str_id_to_int(df, id_column="id"): """ Converts the legacy id format "dataset_name-0000034" diff --git a/nemo_curator/utils/fuzzy_dedup_utils/io_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/io_utils.py index 842613a0..cc6e0909 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/io_utils.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/io_utils.py @@ -76,7 +76,9 @@ def get_text_ddf_from_json_path_with_blocksize( text_ddf = text_ddf.map_partitions( convert_str_id_to_int, id_column=id_column, - meta=cudf.DataFrame({text_column: ["a"], "doc_id": [0], "dataset_id": np.uint32(1)}), + meta=cudf.DataFrame( + {text_column: ["a"], "doc_id": [0], "dataset_id": np.uint32(1)} + ), ) return text_ddf diff --git a/nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py index 0b686711..6a70e6d2 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py @@ -14,9 +14,11 @@ from __future__ import annotations -import numpy as np -import numba from typing import Tuple + +import numba +import numpy as np + from nemo_curator._compat import DASK_SHUFFLE_METHOD_ARG diff --git a/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py index 020cab33..e104ee0c 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py @@ -17,9 +17,7 @@ import numpy as np from dask import config from dask.dataframe.shuffle import rearrange_by_column -from dask_cuda.explicit_comms.dataframe.shuffle import ( - shuffle as explicit_comms_shuffle, -) +from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle from packaging.version import Version from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import ( diff --git a/nemo_curator/utils/gpu_utils.py b/nemo_curator/utils/gpu_utils.py index c7b1e89b..de1c23df 100644 --- a/nemo_curator/utils/gpu_utils.py +++ b/nemo_curator/utils/gpu_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + def is_cudf_type(obj): """ Check if an object is a cuDF type @@ -22,4 +23,3 @@ def is_cudf_type(obj): str(getattr(obj, "_meta", "")), ] return any("cudf" in obj_type for obj_type in types) - diff --git a/nemo_curator/utils/script_utils.py b/nemo_curator/utils/script_utils.py index c005c0fd..8da562d3 100644 --- a/nemo_curator/utils/script_utils.py +++ b/nemo_curator/utils/script_utils.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import argparse +import os from itertools import islice @@ -108,6 +108,7 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> argparse.ArgumentPa return parser + def chunk_list(lst, nchnks): nitem = len(lst) splits = splitnum(nitem, nchnks) diff --git a/nemo_curator/utils/text_utils.py b/nemo_curator/utils/text_utils.py index 4686d113..9b4429f2 100644 --- a/nemo_curator/utils/text_utils.py +++ b/nemo_curator/utils/text_utils.py @@ -12,45 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import ast -import warnings +import os +import string import tokenize -from itertools import groupby +import warnings from io import StringIO -import string +from itertools import groupby def get_word_splitter(language): - language = language.lower() - if language == 'zh': - import jieba - return jieba.cut - else: - return default_splitter + language = language.lower() + if language == "zh": + import jieba + + return jieba.cut + else: + return default_splitter def default_splitter(document): - return document.split() + return document.split() def get_paragraphs(document): - # Split the document into paragraphs. - # A paragraph is defined as a sequence of lines - # separated by a double newline. - return document.split('\n\n') + # Split the document into paragraphs. + # A paragraph is defined as a sequence of lines + # separated by a double newline. + return document.split("\n\n") def get_sentences(document): - # Split the document into sentences. - # A sentence is defined as a sequence of lines separated - # by a single newline. - return [x for x in document.split('\n') if len(x.strip()) > 0] + # Split the document into sentences. + # A sentence is defined as a sequence of lines separated + # by a single newline. + return [x for x in document.split("\n") if len(x.strip()) > 0] def get_ngrams(input_list, n): - # Fast function to return n-grams from a list of tokens. - return [item for item in zip(*[input_list[i:] for i in range(n)])] + # Fast function to return n-grams from a list of tokens. + return [item for item in zip(*[input_list[i:] for i in range(n)])] def is_paragraph_indices_in_top_or_bottom_only( @@ -58,130 +59,134 @@ def is_paragraph_indices_in_top_or_bottom_only( num_paragraphs, ): - def _is_contiguous(indices): + def _is_contiguous(indices): + # Indices are sorted in ascending order. + num_indices = len(indices) - 1 + return all(indices[i] + 1 == indices[i + 1] for i in range(num_indices)) + + # See if the indices are contiguous and exclusively at the top/bottom. # Indices are sorted in ascending order. - num_indices = len(indices) - 1 - return all(indices[i] + 1 == indices[i + 1] for i in range(num_indices)) - - # See if the indices are contiguous and exclusively at the top/bottom. - # Indices are sorted in ascending order. 
- # If num_paragraphs = 11: - # Valid indices example : [0, 1, 9, 10] - # Invalid indices example : [0, 1, 3, 9, 10] - # Invalid indices example : [0, 1, 3, 5, 6, 9, 10] - # Invalid indices example : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - if len(boilerplate_paragraph_indices) == num_paragraphs: - return False - return _is_contiguous(boilerplate_paragraph_indices) and ( - boilerplate_paragraph_indices[0] == 0 or - boilerplate_paragraph_indices[-1] == num_paragraphs - 1) + # If num_paragraphs = 11: + # Valid indices example : [0, 1, 9, 10] + # Invalid indices example : [0, 1, 3, 9, 10] + # Invalid indices example : [0, 1, 3, 5, 6, 9, 10] + # Invalid indices example : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + if len(boilerplate_paragraph_indices) == num_paragraphs: + return False + return _is_contiguous(boilerplate_paragraph_indices) and ( + boilerplate_paragraph_indices[0] == 0 + or boilerplate_paragraph_indices[-1] == num_paragraphs - 1 + ) # Node types for processing abstract syntax tree NODE_TYPES = { - ast.ClassDef: 'Class', - ast.FunctionDef: 'Function/Method', - ast.Module: 'Module' + ast.ClassDef: "Class", + ast.FunctionDef: "Function/Method", + ast.Module: "Module", } def get_comments_and_docstring(source, comments=True, clean_comments=False): - """ - Extract all natural text in source: comments + doctsrings - the extraction fails in case of syntax errors in the file - Args: - source: the code to parse - comments: if True extract comments two - clean_comment: if True remove # from extracted comments - Returns: - a string with concatenated docstrings and comments - """ - - try: - docstrings = '\n'.join(get_docstrings(source)) - except Exception: - docstrings = None - warnings.warn("code couldn't be parsed due to compilation failure, " - "no docstring is extracted") - - if comments: + """ + Extract all natural text in source: comments + doctsrings + the extraction fails in case of syntax errors in the file + Args: + source: the code to parse + comments: if True extract comments two + clean_comment: if True remove # from extracted comments + Returns: + a string with concatenated docstrings and comments + """ + try: - comments = get_comments(source, clean=clean_comments) + docstrings = "\n".join(get_docstrings(source)) except Exception: - comments = None - warnings.warn("tokenization error, no comment is extracted") - else: - comments = '' + docstrings = None + warnings.warn( + "code couldn't be parsed due to compilation failure, " + "no docstring is extracted" + ) + + if comments: + try: + comments = get_comments(source, clean=clean_comments) + except Exception: + comments = None + warnings.warn("tokenization error, no comment is extracted") + else: + comments = "" - return docstrings, comments + return docstrings, comments def get_comments(s, clean=False): - "Returns a string including all coments" - coments = [] - g = tokenize.generate_tokens(StringIO(s).readline) - for toknum, tokval, _, _, _ in g: - # print(toknum,tokval) - if toknum == tokenize.COMMENT: - coments.append((toknum, tokval)) - result = tokenize.untokenize(coments) - if clean: - result = result.replace('#', '') - return result - - -def get_docstrings(source, module=''): - """Parse Python source code from file or string and print docstrings.""" - if hasattr(source, 'read'): - filename = getattr(source, 'name', module) - module = os.path.splitext(os.path.basename(filename))[0] - source = source.read() - - docstrings = sorted(parse_docstrings(source), - key=lambda x: (NODE_TYPES.get(type(x[0])), x[1])) - - grouped = 
groupby(docstrings, key=lambda x: NODE_TYPES.get(type(x[0]))) - results = [] - for _, group in grouped: - for _, name, docstring in group: - name = name if name else module - if docstring: - results.append(docstring) - return results + "Returns a string including all coments" + coments = [] + g = tokenize.generate_tokens(StringIO(s).readline) + for toknum, tokval, _, _, _ in g: + # print(toknum,tokval) + if toknum == tokenize.COMMENT: + coments.append((toknum, tokval)) + result = tokenize.untokenize(coments) + if clean: + result = result.replace("#", "") + return result + + +def get_docstrings(source, module=""): + """Parse Python source code from file or string and print docstrings.""" + if hasattr(source, "read"): + filename = getattr(source, "name", module) + module = os.path.splitext(os.path.basename(filename))[0] + source = source.read() + + docstrings = sorted( + parse_docstrings(source), key=lambda x: (NODE_TYPES.get(type(x[0])), x[1]) + ) + + grouped = groupby(docstrings, key=lambda x: NODE_TYPES.get(type(x[0]))) + results = [] + for _, group in grouped: + for _, name, docstring in group: + name = name if name else module + if docstring: + results.append(docstring) + return results def parse_docstrings(source): - """Parse Python source code and yield a tuple of ast node instance, name, + """Parse Python source code and yield a tuple of ast node instance, name, and docstring for each function/method, class and module.""" - tree = ast.parse(source) + tree = ast.parse(source) - for node in ast.walk(tree): - if isinstance(node, tuple(NODE_TYPES)): - docstring = ast.get_docstring(node) + for node in ast.walk(tree): + if isinstance(node, tuple(NODE_TYPES)): + docstring = ast.get_docstring(node) - yield (node, getattr(node, 'name', None), docstring) + yield (node, getattr(node, "name", None), docstring) def remove_punctuation(str_in): - return str_in.translate(str_in.maketrans('', '', string.punctuation)) + return str_in.translate(str_in.maketrans("", "", string.punctuation)) def get_words(text): - word_start_char_positions = [] - prev = 0 - words = [] - - text = text.lower() - text = remove_punctuation(text) - if len(text) > 0: - for i in range(len(text)): - if text[i] != ' ': - if i == 0 or text[i - 1] == ' ': - word_start_char_positions.append(i) - if i != 0: - words.append(text[prev:i].strip()) - prev = i - words.append(text[prev:i + 1].strip()) - if words[0] == '': - words = words[1:] - return words, word_start_char_positions + word_start_char_positions = [] + prev = 0 + words = [] + + text = text.lower() + text = remove_punctuation(text) + if len(text) > 0: + for i in range(len(text)): + if text[i] != " ": + if i == 0 or text[i - 1] == " ": + word_start_char_positions.append(i) + if i != 0: + words.append(text[prev:i].strip()) + prev = i + words.append(text[prev : i + 1].strip()) + if words[0] == "": + words = words[1:] + return words, word_start_char_positions diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..a026c06c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[tool.isort] +profile = "black" # black-compatible +line_length = 88 # should match black parameters +py_version = 310 +extend_skip = ["setup.py"] + +[tool.black] +line_length = 88 + +[tool.pytest.ini_options] +markers = [ + "gpu: marks tests as GPU tests (deselect with '-m \"not gpu\"')" +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 9f07ebda..00000000 --- a/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -markers = - gpu: marks tests as GPU tests (deselect with '-m \"not gpu\"') - diff --git a/setup.py b/setup.py index 807f2513..ae515757 100644 --- a/setup.py +++ b/setup.py @@ -17,88 +17,88 @@ here = pathlib.Path(__file__).parent.resolve() -long_description = (here / 'README.md').read_text(encoding='utf-8') +long_description = (here / "README.md").read_text(encoding="utf-8") setup( - name='nemo_curator', - version='0.2.0', - description='Scalable Data Preprocessing Tool for ' - 'Training Large Language Models', + name="nemo_curator", + version="0.2.0", + description="Scalable Data Preprocessing Tool for " + "Training Large Language Models", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/NVIDIA/NeMo-Curator', - author='Joseph Jennings, Mostofa Patwary, Sandeep Subramanian, ' - 'Shrimai Prabhumoye, Ayush Dattagupta, Vibhu Jawa, Jiwei Liu, Ryan Wolf', - author_email='jjennings@nvidia.com, mpatwary@nvidia.com, ' - 'rywolf@nvidia.com, sprabhumoye@nvidia.com', + long_description_content_type="text/markdown", + url="https://github.com/NVIDIA/NeMo-Curator", + author="Joseph Jennings, Mostofa Patwary, Sandeep Subramanian, " + "Shrimai Prabhumoye, Ayush Dattagupta, Vibhu Jawa, Jiwei Liu, Ryan Wolf", + author_email="jjennings@nvidia.com, mpatwary@nvidia.com, " + "rywolf@nvidia.com, sprabhumoye@nvidia.com", classifiers=[ - 'Development Status :: 3 - Alpha', - 'Programming Language :: Python :: 3', + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", ], packages=find_packages(), - python_requires='>=3.7', + python_requires=">=3.7", install_requires=[ - 'dask[complete]>=2021.7.1', - 'distributed>=2021.7.1', - 'dask-mpi>=2021.11.0', - 'charset_normalizer>=3.1.0', - 'awscli>=1.22.55', - 'fasttext==0.9.2', - 'pycld2==0.41', - 'justext==3.0.0', - 'ftfy==6.1.1', - 'warcio==1.7.4', - 'zstandard==0.18.0', - 'in-place==0.5.0', - 'unidic-lite==1.0.8', - 'jieba==0.42.1', - 'comment_parser', - 'beautifulsoup4', - 'mwparserfromhell @ git+https://github.com/earwig/mwparserfromhell.git@0f89f44', - 'cudf-cu12==23.10.*', - 'dask-cudf-cu12==23.10.*', - 'cugraph-cu12==23.10.*', - 'dask-cuda==23.10.*', - 'spacy>=3.6.0, <4.0.0', - 'presidio-analyzer==2.2.351', - 'presidio-anonymizer==2.2.351', - 'usaddress==0.5.10', - 'nemo_toolkit[nlp]>=1.23.0' + "dask[complete]>=2021.7.1", + "distributed>=2021.7.1", + "dask-mpi>=2021.11.0", + "charset_normalizer>=3.1.0", + "awscli>=1.22.55", + "fasttext==0.9.2", + "pycld2==0.41", + "justext==3.0.0", + "ftfy==6.1.1", + "warcio==1.7.4", + "zstandard==0.18.0", + "in-place==0.5.0", + "unidic-lite==1.0.8", + "jieba==0.42.1", + "comment_parser", + 
"beautifulsoup4", + "mwparserfromhell @ git+https://github.com/earwig/mwparserfromhell.git@0f89f44", + "cudf-cu12==23.10.*", + "dask-cudf-cu12==23.10.*", + "cugraph-cu12==23.10.*", + "dask-cuda==23.10.*", + "spacy>=3.6.0, <4.0.0", + "presidio-analyzer==2.2.351", + "presidio-anonymizer==2.2.351", + "usaddress==0.5.10", + "nemo_toolkit[nlp]>=1.23.0", ], entry_points={ - 'console_scripts': [ - 'get_common_crawl_urls=nemo_curator.scripts.get_common_crawl_urls:console_script', - 'get_wikipedia_urls=nemo_curator.scripts.get_wikipedia_urls:console_script', - 'download_and_extract=nemo_curator.scripts.download_and_extract:console_script', - 'text_cleaning=nemo_curator.scripts.text_cleaning:console_script', - 'add_id=nemo_curator.scripts.add_id:console_script', - 'get_metadata_from_corpus=nemo_curator.get_metadata_from_corpus:console_script', - 'make_data_shards=nemo_curator.scripts.make_data_shards:console_script', - 'prepare_fasttext_training_data=nemo_curator.scripts.prepare_fasttext_training_data:console_script', - 'train_fasttext=nemo_curator.scripts.train_fasttext:console_script', - 'filter_documents=nemo_curator.scripts.filter_documents:console_script', - 'separate_by_metadata=nemo_curator.scripts.separate_by_metadata:console_script', - 'prepare_task_data=nemo_curator.scripts.prepare_task_data:console_script', - 'find_matching_ngrams=nemo_curator.scripts.find_matching_ngrams:console_script', - 'remove_matching_ngrams=nemo_curator.scripts.remove_matching_ngrams:console_script', - 'gpu_compute_minhashes=nemo_curator.scripts.compute_minhashes:console_script', - 'minhash_buckets=nemo_curator.scripts.minhash_lsh:console_script', - 'jaccard_map_buckets=nemo_curator.scripts.map_buckets:console_script', - 'jaccard_shuffle=nemo_curator.scripts.jaccard_shuffle:console_script', - 'jaccard_compute=nemo_curator.scripts.jaccard_compute:console_script', - 'gpu_connected_component=nemo_curator.scripts.connected_components:console_script', - 'write_deduped_result_with_text=nemo_curator.gpu_deduplication.write_deduped_result_with_text:console_script', - 'verify_all_pairs_jaccard=nemo_curator.gpu_deduplication.verify_all_pairs_jaccard:console_script', - 'gpu_exact_dups=nemo_curator.scripts.find_exact_duplicates:console_script', - 'prepare_fuzzy_ids=nemo_curator.gpu_deduplication.prepare_fuzzy_ids:console_script', - 'create_list_of_duplicate_ids=nemo_curator.gpu_deduplication.create_list_of_duplicate_ids:console_script', - 'remove_duplicates=nemo_curator.gpu_deduplication.remove_duplicates:console_script', - 'deidentify=nemo_curator.scripts.find_pii_and_deidentify:console_script', - 'generate_statistics=nemo_curator.distributed_data_classification.generate_statistics:console_script', - 'domain_classifier_inference=nemo_curator.distributed_data_classification.domain_classifier_inference:console_script', - 'quality_classifier_multiple_models_inference=nemo_curator.distributed_data_classification.quality_classifier_multiple_models_inference:console_script', - 'quality_classifier_inference=nemo_curator.distributed_data_classification.quality_classifier_inference:console_script', - 'verify_results=nemo_curator.distributed_data_classification.verify_results:console_script', + "console_scripts": [ + "get_common_crawl_urls=nemo_curator.scripts.get_common_crawl_urls:console_script", + "get_wikipedia_urls=nemo_curator.scripts.get_wikipedia_urls:console_script", + "download_and_extract=nemo_curator.scripts.download_and_extract:console_script", + "text_cleaning=nemo_curator.scripts.text_cleaning:console_script", + 
"add_id=nemo_curator.scripts.add_id:console_script", + "get_metadata_from_corpus=nemo_curator.get_metadata_from_corpus:console_script", + "make_data_shards=nemo_curator.scripts.make_data_shards:console_script", + "prepare_fasttext_training_data=nemo_curator.scripts.prepare_fasttext_training_data:console_script", + "train_fasttext=nemo_curator.scripts.train_fasttext:console_script", + "filter_documents=nemo_curator.scripts.filter_documents:console_script", + "separate_by_metadata=nemo_curator.scripts.separate_by_metadata:console_script", + "prepare_task_data=nemo_curator.scripts.prepare_task_data:console_script", + "find_matching_ngrams=nemo_curator.scripts.find_matching_ngrams:console_script", + "remove_matching_ngrams=nemo_curator.scripts.remove_matching_ngrams:console_script", + "gpu_compute_minhashes=nemo_curator.scripts.compute_minhashes:console_script", + "minhash_buckets=nemo_curator.scripts.minhash_lsh:console_script", + "jaccard_map_buckets=nemo_curator.scripts.map_buckets:console_script", + "jaccard_shuffle=nemo_curator.scripts.jaccard_shuffle:console_script", + "jaccard_compute=nemo_curator.scripts.jaccard_compute:console_script", + "gpu_connected_component=nemo_curator.scripts.connected_components:console_script", + "write_deduped_result_with_text=nemo_curator.gpu_deduplication.write_deduped_result_with_text:console_script", + "verify_all_pairs_jaccard=nemo_curator.gpu_deduplication.verify_all_pairs_jaccard:console_script", + "gpu_exact_dups=nemo_curator.scripts.find_exact_duplicates:console_script", + "prepare_fuzzy_ids=nemo_curator.gpu_deduplication.prepare_fuzzy_ids:console_script", + "create_list_of_duplicate_ids=nemo_curator.gpu_deduplication.create_list_of_duplicate_ids:console_script", + "remove_duplicates=nemo_curator.gpu_deduplication.remove_duplicates:console_script", + "deidentify=nemo_curator.scripts.find_pii_and_deidentify:console_script", + "generate_statistics=nemo_curator.distributed_data_classification.generate_statistics:console_script", + "domain_classifier_inference=nemo_curator.distributed_data_classification.domain_classifier_inference:console_script", + "quality_classifier_multiple_models_inference=nemo_curator.distributed_data_classification.quality_classifier_multiple_models_inference:console_script", + "quality_classifier_inference=nemo_curator.distributed_data_classification.quality_classifier_inference:console_script", + "verify_results=nemo_curator.distributed_data_classification.verify_results:console_script", ], }, ) diff --git a/tests/__init__.py b/tests/__init__.py index fe99e99a..d9155f92 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/tests/pii_data/address.txt b/tests/pii_data/address.txt index 6642380b..f583b06e 100644 --- a/tests/pii_data/address.txt +++ b/tests/pii_data/address.txt @@ -47,4 +47,4 @@ He curently resides at
30 Memorial Drive, Avon MA 2322
101 Sanford Farm Shpg Center, Amsterdam NY 12010
297 Grant Avenue, Auburn NY 13021
4133 Veterans Memorial Drive, Batavia NY 14020
-6265 Brockport Spencerport Rd, Brockport NY 14420
\ No newline at end of file
+6265 Brockport Spencerport Rd, Brockport NY 14420
diff --git a/tests/pii_data/birthdates.txt b/tests/pii_data/birthdates.txt index 87c9a8ef..958ba281 100644 --- a/tests/pii_data/birthdates.txt +++ b/tests/pii_data/birthdates.txt @@ -1 +1 @@ -I was born on December 5, 1983. \ No newline at end of file +I was born on December 5, 1983. diff --git a/tests/pii_data/card_no.txt b/tests/pii_data/card_no.txt index ae2a039e..48afc272 100644 --- a/tests/pii_data/card_no.txt +++ b/tests/pii_data/card_no.txt @@ -1,3 +1,3 @@ 4929-3813-3266-4295 345389698201044 -I have changed my credit card. My new credit card number is 5299-1561-5689-1938. \ No newline at end of file +I have changed my credit card. My new credit card number is 5299-1561-5689-1938. diff --git a/tests/pii_data/emails.txt b/tests/pii_data/emails.txt index c2ed5e58..6d238517 100644 --- a/tests/pii_data/emails.txt +++ b/tests/pii_data/emails.txt @@ -1,3 +1,3 @@ Hello, I am John. My email is johndoe@openai.com. Feel free to drop me an email. testbotaccount@nvidia.com -testbotaccount@gmail.com \ No newline at end of file +testbotaccount@gmail.com diff --git a/tests/pii_data/ip_address.txt b/tests/pii_data/ip_address.txt index 33b19ff2..823393aa 100644 --- a/tests/pii_data/ip_address.txt +++ b/tests/pii_data/ip_address.txt @@ -1,2 +1,2 @@ inet 10.41.21.72 -Can you ping me on my ip 192.168.0.12? \ No newline at end of file +Can you ping me on my ip 192.168.0.12? diff --git a/tests/pii_data/multiple.txt b/tests/pii_data/multiple.txt index 41afb963..ced2b754 100644 --- a/tests/pii_data/multiple.txt +++ b/tests/pii_data/multiple.txt @@ -1,2 +1,2 @@ Hello, I am David. I was born on December 5, 1983. My email is david@company.com and you can call me on (845) 450 5693 -My SSN is 284-89-3485. I live at
2456 Main Street, Silver Spring MD
\ No newline at end of file
+My SSN is 284-89-3485. I live at 2456 Main Street, Silver Spring MD
diff --git a/tests/pii_data/names.txt b/tests/pii_data/names.txt index 292be921..7c216643 100644 --- a/tests/pii_data/names.txt +++ b/tests/pii_data/names.txt @@ -1,4 +1,4 @@ Surprisingly, John, Shaun, and Ron finished their exam tests all at the same time. Anton Perera Sameer Rawat -Ilya Sutskever \ No newline at end of file +Ilya Sutskever diff --git a/tests/pii_data/phone_numbers.txt b/tests/pii_data/phone_numbers.txt index b989a26b..1ed303f0 100644 --- a/tests/pii_data/phone_numbers.txt +++ b/tests/pii_data/phone_numbers.txt @@ -2,4 +2,4 @@ You can call me at (814) 360 8593 +1-212-456-7890 +12312322334 2312322334 -(231) 232-2334 \ No newline at end of file +(231) 232-2334 diff --git a/tests/pii_data/ssn.txt b/tests/pii_data/ssn.txt index d636a2e5..92626a0c 100644 --- a/tests/pii_data/ssn.txt +++ b/tests/pii_data/ssn.txt @@ -1,2 +1,2 @@ 514-14-8905 -My SSN is 284-89-3485. \ No newline at end of file +My SSN is 284-89-3485. diff --git a/tests/test_add_id.py b/tests/test_add_id.py index 08382f70..458b4868 100644 --- a/tests/test_add_id.py +++ b/tests/test_add_id.py @@ -12,26 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pandas as pd import dask.dataframe as dd +import pandas as pd import pytest import nemo_curator from nemo_curator.datasets import DocumentDataset + def list_to_dataset(documents, col_name="text", npartitions=2): data = {col_name: documents} pdf = pd.DataFrame(data) - + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + @pytest.fixture def single_partition_dataset(): - return list_to_dataset(["First", "Second", "Third", "Fourth", "Fifth"], npartitions=1) + return list_to_dataset( + ["First", "Second", "Third", "Fourth", "Fifth"], npartitions=1 + ) + @pytest.fixture def two_partition_dataset(): - return list_to_dataset(["First", "Second", "Third", "Fourth", "Fifth"], npartitions=2) + return list_to_dataset( + ["First", "Second", "Third", "Fourth", "Fifth"], npartitions=2 + ) + class TestPrepareTaskData: def test_basic_id(self, single_partition_dataset): @@ -39,35 +47,75 @@ def test_basic_id(self, single_partition_dataset): add_id = nemo_curator.AddId(id_field) id_dataset = add_id(single_partition_dataset) actual_ids = id_dataset.df[id_field].compute() - expected_ids = pd.Series(["doc_id-0000000000", "doc_id-0000000001", "doc_id-0000000002", "doc_id-0000000003", "doc_id-0000000004"]) + expected_ids = pd.Series( + [ + "doc_id-0000000000", + "doc_id-0000000001", + "doc_id-0000000002", + "doc_id-0000000003", + "doc_id-0000000004", + ] + ) + + assert all( + expected_ids == actual_ids + ), f"Expected: {expected_ids}, got: {actual_ids}" - assert all(expected_ids == actual_ids), f"Expected: {expected_ids}, got: {actual_ids}" - def test_two_partitions(self, two_partition_dataset): id_field = "id" add_id = nemo_curator.AddId(id_field) id_dataset = add_id(two_partition_dataset) actual_ids = id_dataset.df[id_field].compute() - expected_ids = pd.Series(["doc_id-0000000000", "doc_id-0000000001", "doc_id-0000000002", "doc_id-0000000003", "doc_id-0000000004"]) + expected_ids = pd.Series( + [ + "doc_id-0000000000", + "doc_id-0000000001", + "doc_id-0000000002", + "doc_id-0000000003", + "doc_id-0000000004", + ] + ) + + assert all( + expected_ids == actual_ids + ), f"Expected: {expected_ids}, got: {actual_ids}" - assert all(expected_ids == actual_ids), f"Expected: {expected_ids}, got: {actual_ids}" - def test_id_prefix(self, two_partition_dataset): id_field = "id" id_prefix = "my_id" add_id = 
nemo_curator.AddId(id_field, id_prefix=id_prefix) id_dataset = add_id(two_partition_dataset) actual_ids = id_dataset.df[id_field].compute() - expected_ids = pd.Series([f"{id_prefix}-0000000000", f"{id_prefix}-0000000001", f"{id_prefix}-0000000002", f"{id_prefix}-0000000003", f"{id_prefix}-0000000004"]) + expected_ids = pd.Series( + [ + f"{id_prefix}-0000000000", + f"{id_prefix}-0000000001", + f"{id_prefix}-0000000002", + f"{id_prefix}-0000000003", + f"{id_prefix}-0000000004", + ] + ) + + assert all( + expected_ids == actual_ids + ), f"Expected: {expected_ids}, got: {actual_ids}" - assert all(expected_ids == actual_ids), f"Expected: {expected_ids}, got: {actual_ids}" - def test_start_index(self, two_partition_dataset): id_field = "id" start_index = 13 add_id = nemo_curator.AddId(id_field, start_index=start_index) id_dataset = add_id(two_partition_dataset) actual_ids = id_dataset.df[id_field].compute() - expected_ids = pd.Series(["doc_id-0000000013", "doc_id-0000000014", "doc_id-0000000015", "doc_id-0000000016", "doc_id-0000000017"]) + expected_ids = pd.Series( + [ + "doc_id-0000000013", + "doc_id-0000000014", + "doc_id-0000000015", + "doc_id-0000000016", + "doc_id-0000000017", + ] + ) - assert all(expected_ids == actual_ids), f"Expected: {expected_ids}, got: {actual_ids}" \ No newline at end of file + assert all( + expected_ids == actual_ids + ), f"Expected: {expected_ids}, got: {actual_ids}" diff --git a/tests/test_exact_dedup.py b/tests/test_exact_dedup.py index 827253fc..d0408073 100644 --- a/tests/test_exact_dedup.py +++ b/tests/test_exact_dedup.py @@ -16,6 +16,7 @@ import pytest from dask import dataframe as dd from dask.dataframe.utils import assert_eq + from nemo_curator.datasets import DocumentDataset from nemo_curator.modules import ExactDuplicates diff --git a/tests/test_filters.py b/tests/test_filters.py index 77fa27bd..bd6f0e63 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -13,20 +13,53 @@ # limitations under the License. 
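
The AddId tests above pin down the id scheme for this version: a prefix (doc_id unless id_prefix is given) joined to a zero-padded counter that starts at start_index and increases across partitions in order. The same behaviour outside the test fixtures, as a compact sketch:

import dask.dataframe as dd
import pandas as pd

import nemo_curator
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset(
    dd.from_pandas(
        pd.DataFrame({"text": ["first", "second", "third"]}), npartitions=2
    )
)

add_id = nemo_curator.AddId("id", id_prefix="web", start_index=100)
ids = add_id(dataset).df["id"].compute().tolist()
print(ids)
# Expected, given the zero padding shown in the tests above:
# ['web-0000000100', 'web-0000000101', 'web-0000000102']
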
import os -import pandas as pd -from dask import dataframe as dd +import pandas as pd import pytest +from dask import dataframe as dd from nemo_curator.datasets import DocumentDataset -from nemo_curator.modules import ScoreFilter, Score, Filter, Sequential -from nemo_curator.filters import DocumentFilter, NonAlphaNumericFilter, SymbolsToWordsFilter, NumbersFilter, UrlsFilter, BulletsFilter, WhiteSpaceFilter, ParenthesesFilter, LongWordFilter, WordCountFilter, BoilerPlateStringFilter, MeanWordLengthFilter, RepeatedLinesFilter, RepeatedParagraphsFilter, RepeatedLinesByCharFilter, RepeatedParagraphsByCharFilter, RepeatingTopNGramsFilter, RepeatingDuplicateNGramsFilter, PunctuationFilter, EllipsisFilter, CommonEnglishWordsFilter, WordsWithoutAlphabetsFilter, PornographicUrlsFilter -from nemo_curator.filters import PythonCommentToCodeFilter, GeneralCommentToCodeFilter, NumberOfLinesOfCodeFilter, TokenizerFertilityFilter, XMLHeaderFilter, AlphaFilter, HTMLBoilerplateFilter, PerExtensionFilter +from nemo_curator.filters import ( + AlphaFilter, + BoilerPlateStringFilter, + BulletsFilter, + CommonEnglishWordsFilter, + DocumentFilter, + EllipsisFilter, + GeneralCommentToCodeFilter, + HTMLBoilerplateFilter, + LongWordFilter, + MeanWordLengthFilter, + NonAlphaNumericFilter, + NumberOfLinesOfCodeFilter, + NumbersFilter, + ParenthesesFilter, + PerExtensionFilter, + PornographicUrlsFilter, + PunctuationFilter, + PythonCommentToCodeFilter, + RepeatedLinesByCharFilter, + RepeatedLinesFilter, + RepeatedParagraphsByCharFilter, + RepeatedParagraphsFilter, + RepeatingDuplicateNGramsFilter, + RepeatingTopNGramsFilter, + SymbolsToWordsFilter, + TokenizerFertilityFilter, + UrlsFilter, + WhiteSpaceFilter, + WordCountFilter, + WordsWithoutAlphabetsFilter, + XMLHeaderFilter, +) +from nemo_curator.modules import Filter, Score, ScoreFilter, Sequential + class LetterCountFilter(DocumentFilter): """ Keeps documents that have at least some number of a given letter """ + def __init__(self, letter="a", min_count=5): super().__init__() self.letter = letter @@ -34,19 +67,21 @@ def __init__(self, letter="a", min_count=5): def score_document(self, text): return text.count(self.letter) - + def keep_document(self, score): return score >= self.min_count + class BatchedLengthFilter(DocumentFilter): """ Keeps documents of a given length """ + def __init__(self, min_length=5, max_length=10): super().__init__() self.min_length = min_length self.max_length = max_length - + def score_document(self, df): return df.str.len() @@ -55,18 +90,24 @@ def keep_document(self, scores): max_threshold = scores <= self.max_length return min_threshold & max_threshold + def all_equal(left_dataset, right_dataset): return all(left_dataset.df.compute() == right_dataset.df.compute()) + def list_to_dataset(documents, col_name="text", npartitions=2): data = {col_name: documents} pdf = pd.DataFrame(data) - + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + @pytest.fixture def letter_count_data(): - return list_to_dataset(["Two aa", "a a Three a", "Five aaa aa", "aaaSeven aaaa"], col_name="documents") + return list_to_dataset( + ["Two aa", "a a Three a", "Five aaa aa", "aaaSeven aaaa"], col_name="documents" + ) + class TestFilterModule: def test_score_filter(self, letter_count_data): @@ -76,33 +117,49 @@ def test_score_filter(self, letter_count_data): expected_indices = [2, 3] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got 
{filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_score(self, letter_count_data): letter_filter = LetterCountFilter() score_field = "a_count" - score_step = Score(letter_filter.score_document, text_field="documents", score_field=score_field) + score_step = Score( + letter_filter.score_document, + text_field="documents", + score_field=score_field, + ) scored_data = score_step(letter_count_data) expected_scores = pd.Series([2, 3, 5, 7]) scores = scored_data.df[score_field] - assert all(expected_scores == scores.compute()), f"Expected {expected_scores} but got {scores}" + assert all( + expected_scores == scores.compute() + ), f"Expected {expected_scores} but got {scores}" def test_retain_score_filter(self, letter_count_data): letter_filter = LetterCountFilter() score_field = "count_a" - filter_step = ScoreFilter(letter_filter, text_field="documents", score_field=score_field) + filter_step = ScoreFilter( + letter_filter, text_field="documents", score_field=score_field + ) filtered_data = filter_step(letter_count_data) expected_indices = [2, 3] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) expected_data.df[score_field] = pd.Series([5, 7], index=expected_data.df.index) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" def test_filter(self, letter_count_data): letter_filter = LetterCountFilter() score_field = "a_count" - score_step = Score(letter_filter.score_document, text_field="documents", score_field=score_field) + score_step = Score( + letter_filter.score_document, + text_field="documents", + score_field=score_field, + ) scored_data = score_step(letter_count_data) filter_step = Filter(letter_filter.keep_document, score_field) filtered_data = filter_step(scored_data) @@ -111,7 +168,9 @@ def test_filter(self, letter_count_data): expected_data = letter_count_data.df.loc[expected_indices] expected_data[score_field] = pd.Series([5, 7], index=expected_data.index) expected_data = DocumentDataset(expected_data) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" def test_invert(self, letter_count_data): letter_filter = LetterCountFilter() @@ -120,18 +179,24 @@ def test_invert(self, letter_count_data): expected_indices = [0, 1] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" def test_sequential_filter(self, letter_count_data): - filters = Sequential([ - ScoreFilter(LetterCountFilter(), text_field="documents"), - ScoreFilter(LetterCountFilter(min_count=6), text_field="documents") - ]) + filters = Sequential( + [ + ScoreFilter(LetterCountFilter(), text_field="documents"), + ScoreFilter(LetterCountFilter(min_count=6), text_field="documents"), + ] + ) filtered_data = filters(letter_count_data) expected_indices = [3] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + 
), f"Expected {expected_data} but got {filtered_data}" def test_batch_score_filter(self, letter_count_data): length_filter = BatchedLengthFilter(min_length=8, max_length=11) @@ -140,22 +205,36 @@ def test_batch_score_filter(self, letter_count_data): expected_indices = [1, 2] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_batch_score(self, letter_count_data): length_filter = BatchedLengthFilter(min_length=8, max_length=11) score_field = "lengths" - score_step = Score(length_filter.score_document, text_field="documents", score_field=score_field, batched=True) + score_step = Score( + length_filter.score_document, + text_field="documents", + score_field=score_field, + batched=True, + ) scored_data = score_step(letter_count_data) expected_scores = pd.Series([6, 11, 11, 13]) scores = scored_data.df[score_field] - assert all(expected_scores == scores.compute()), f"Expected {expected_scores} but got {scores}" - + assert all( + expected_scores == scores.compute() + ), f"Expected {expected_scores} but got {scores}" + def test_batch_filter(self, letter_count_data): length_filter = BatchedLengthFilter(min_length=8, max_length=11) score_field = "lengths" - score_step = Score(length_filter.score_document, text_field="documents", score_field=score_field, batched=True) + score_step = Score( + length_filter.score_document, + text_field="documents", + score_field=score_field, + batched=True, + ) scored_data = score_step(letter_count_data) filter_step = Filter(length_filter.keep_document, score_field, batched=True) filtered_data = filter_step(scored_data) @@ -164,8 +243,10 @@ def test_batch_filter(self, letter_count_data): expected_data = letter_count_data.df.loc[expected_indices] expected_data[score_field] = pd.Series([11, 11], index=expected_data.index) expected_data = DocumentDataset(expected_data) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_score_filter_type(self, letter_count_data): letter_filter = LetterCountFilter() filter_step = ScoreFilter(letter_filter, text_field="documents", score_type=int) @@ -173,65 +254,110 @@ def test_score_filter_type(self, letter_count_data): expected_indices = [2, 3] expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_score_type(self, letter_count_data): letter_filter = LetterCountFilter() score_field = "a_count" - score_step = Score(letter_filter.score_document, text_field="documents", score_field=score_field, score_type=int) + score_step = Score( + letter_filter.score_document, + text_field="documents", + score_field=score_field, + score_type=int, + ) scored_data = score_step(letter_count_data) expected_scores = pd.Series([2, 3, 5, 7]) scores = scored_data.df[score_field] - assert all(expected_scores == scores.compute()), f"Expected {expected_scores} but got {scores}" + assert all( + expected_scores == scores.compute() + ), f"Expected {expected_scores} but got {scores}" class TestHeuristicFilters: def 
test_nonalpha(self): - dataset = list_to_dataset(["", "This is a test case.", "%$^%$^%$&^$()))))", "$aaa"]) + dataset = list_to_dataset( + ["", "This is a test case.", "%$^%$^%$&^$()))))", "$aaa"] + ) filters = ScoreFilter(NonAlphaNumericFilter()) filtered_data = filters(dataset) expected_indices = [1, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_symbolswords(self): - dataset = list_to_dataset(["mixed bag ... #", "full of words", "... # ... # #", "barely ok 3 4 5 6 7 8 9 #"]) + dataset = list_to_dataset( + [ + "mixed bag ... #", + "full of words", + "... # ... # #", + "barely ok 3 4 5 6 7 8 9 #", + ] + ) filters = ScoreFilter(SymbolsToWordsFilter()) filtered_data = filters(dataset) expected_indices = [1, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_numbers(self): - dataset = list_to_dataset(["purely letters", "34134543", "$!@$@!$!@", "abcdefghi1"]) + dataset = list_to_dataset( + ["purely letters", "34134543", "$!@$@!$!@", "abcdefghi1"] + ) filters = ScoreFilter(NumbersFilter(max_number_to_text_ratio=0.1)) filtered_data = filters(dataset) expected_indices = [0, 2, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_urls(self): - dataset = list_to_dataset(["https://www.nvidia.com/en-us/", "no urls here!", "$!@$@!$!@", "bunch of other words with url afdsjafidsaofjbwreowihfdsafbdashuoiotauhiofdafdsafd fdasfdafdsafdsafdsafdsafdsafdsa https://www.nvidia.com/en-us/ something else after the url etc more and more", "words with url https://www.nvidia.com/en-us/"]) + dataset = list_to_dataset( + [ + "https://www.nvidia.com/en-us/", + "no urls here!", + "$!@$@!$!@", + "bunch of other words with url afdsjafidsaofjbwreowihfdsafbdashuoiotauhiofdafdsafd fdasfdafdsafdsafdsafdsafdsafdsa https://www.nvidia.com/en-us/ something else after the url etc more and more", + "words with url https://www.nvidia.com/en-us/", + ] + ) filters = ScoreFilter(UrlsFilter()) filtered_data = filters(dataset) expected_indices = [1, 2, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_bullets(self): - dataset = list_to_dataset(["• not good", "good", "50 \n ⦾ 50", "⁌ this \n⁌ should \n⁌barely \n⁌pass \n⁌5 \n⁌6 \n⁌7 \n⁌8 \n⁌9 \n done!"]) + dataset = list_to_dataset( + [ + "• not good", + "good", + "50 \n ⦾ 50", + "⁌ this \n⁌ should \n⁌barely \n⁌pass \n⁌5 \n⁌6 \n⁌7 \n⁌8 \n⁌9 \n done!", + ] + ) filters = ScoreFilter(BulletsFilter()) filtered_data = filters(dataset) expected_indices = [1, 2, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert 
all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_whitespace(self): dataset = list_to_dataset(["\t\n\r", "good", "50%\n\n\n", "123\b"]) filters = ScoreFilter(WhiteSpaceFilter()) @@ -239,17 +365,23 @@ def test_whitespace(self): expected_indices = [1, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_parentheses(self): - dataset = list_to_dataset(["()", "(not good)", "this is completely absolutely fine", "123456789("]) + dataset = list_to_dataset( + ["()", "(not good)", "this is completely absolutely fine", "123456789("] + ) filters = ScoreFilter(ParenthesesFilter()) filtered_data = filters(dataset) expected_indices = [2, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_longword(self): dataset = list_to_dataset(["tiny", "large"]) filters = ScoreFilter(LongWordFilter(max_word_length=4)) @@ -257,35 +389,59 @@ def test_longword(self): expected_indices = [0] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_wordcount(self): - dataset = list_to_dataset(["", "one", "two words", "$#@$ %$@$#@ !#@!", "one two three four five"]) + dataset = list_to_dataset( + ["", "one", "two words", "$#@$ %$@$#@ !#@!", "one two three four five"] + ) filters = ScoreFilter(WordCountFilter(min_words=2, max_words=4)) filtered_data = filters(dataset) expected_indices = [2, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_boilerplate(self): - dataset = list_to_dataset(["nothing\t here", "1\n\n2\n\n3\n\n4\n\n5\n\n6\n\nterms of use\n\n privacy policy\n\n cookie policy\n\nuses cookies", "too much \n\n privacy & cookies policy"]) + dataset = list_to_dataset( + [ + "nothing\t here", + "1\n\n2\n\n3\n\n4\n\n5\n\n6\n\nterms of use\n\n privacy policy\n\n cookie policy\n\nuses cookies", + "too much \n\n privacy & cookies policy", + ] + ) filters = ScoreFilter(BoilerPlateStringFilter()) filtered_data = filters(dataset) expected_indices = [0, 1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_meanwordlength(self): - dataset = list_to_dataset(["a", "aa", "superlongword short", "evenly balanced", "waytoolongforasingleword"]) + dataset = list_to_dataset( + [ + "a", + "aa", + "superlongword short", + "evenly balanced", + "waytoolongforasingleword", + ] + ) filters = ScoreFilter(MeanWordLengthFilter()) filtered_data = filters(dataset) expected_indices = [2, 3] expected_data = 
DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatedlines(self): dataset = list_to_dataset(["totally unique", "half.\nhalf."]) filters = ScoreFilter(RepeatedLinesFilter()) @@ -293,8 +449,10 @@ def test_repeatedlines(self): expected_indices = [0] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatedparagraphs(self): dataset = list_to_dataset(["totally unique", "half.\n\nhalf."]) filters = ScoreFilter(RepeatedParagraphsFilter()) @@ -302,62 +460,107 @@ def test_repeatedparagraphs(self): expected_indices = [0] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatedlineschar(self): - dataset = list_to_dataset(["totally unique", "a.\na.\nvery very very short duplicate.", "half.\nhalf.", "super very incredibly huge long duplicate.\nsuper very incredibly huge long duplicate.\na.\nb.\nc."]) + dataset = list_to_dataset( + [ + "totally unique", + "a.\na.\nvery very very short duplicate.", + "half.\nhalf.", + "super very incredibly huge long duplicate.\nsuper very incredibly huge long duplicate.\na.\nb.\nc.", + ] + ) filters = ScoreFilter(RepeatedLinesByCharFilter()) filtered_data = filters(dataset) expected_indices = [0, 1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatedparagraphschar(self): - dataset = list_to_dataset(["totally unique", "a.\n\n a.\n\n very very very short duplicate.", "half.\n\nhalf.", "super very incredibly huge long duplicate.\n\nsuper very incredibly huge long duplicate.\n\n a.\n\n b.\n\n c."]) + dataset = list_to_dataset( + [ + "totally unique", + "a.\n\n a.\n\n very very very short duplicate.", + "half.\n\nhalf.", + "super very incredibly huge long duplicate.\n\nsuper very incredibly huge long duplicate.\n\n a.\n\n b.\n\n c.", + ] + ) filters = ScoreFilter(RepeatedParagraphsByCharFilter()) filtered_data = filters(dataset) expected_indices = [0, 1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatingtopngrams(self): - dataset = list_to_dataset(["this is a totally fine sentence with no repeating ngrams so we are ok", "a b . a b", "a a a a a a", "totally fine small dupe a b a b"]) + dataset = list_to_dataset( + [ + "this is a totally fine sentence with no repeating ngrams so we are ok", + "a b . 
a b", + "a a a a a a", + "totally fine small dupe a b a b", + ] + ) filters = ScoreFilter(RepeatingTopNGramsFilter()) filtered_data = filters(dataset) expected_indices = [0, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_repeatingduplicatengrams(self): - dataset = list_to_dataset(["a a b b a a b b", "totally fine", "a a a a this should be fine as well"]) + dataset = list_to_dataset( + ["a a b b a a b b", "totally fine", "a a a a this should be fine as well"] + ) filters = ScoreFilter(RepeatingDuplicateNGramsFilter()) filtered_data = filters(dataset) expected_indices = [1, 2] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_punctuation(self): - dataset = list_to_dataset(["not good", "good.", "just\n barely\n fine\n ok\n yep."]) - filters = ScoreFilter(PunctuationFilter(max_num_sentences_without_endmark_ratio=0.8)) + dataset = list_to_dataset( + ["not good", "good.", "just\n barely\n fine\n ok\n yep."] + ) + filters = ScoreFilter( + PunctuationFilter(max_num_sentences_without_endmark_ratio=0.8) + ) filtered_data = filters(dataset) expected_indices = [1, 2] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_ellipsis(self): - dataset = list_to_dataset(["not good...", "good.", "just...\n barely...\n fine...\n ok...\n yep."]) - filters = ScoreFilter(EllipsisFilter(max_num_lines_ending_with_ellipsis_ratio=0.8)) + dataset = list_to_dataset( + ["not good...", "good.", "just...\n barely...\n fine...\n ok...\n yep."] + ) + filters = ScoreFilter( + EllipsisFilter(max_num_lines_ending_with_ellipsis_ratio=0.8) + ) filtered_data = filters(dataset) expected_indices = [1, 2] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_commonenglishwords(self): dataset = list_to_dataset(["uncommon", "the and", "the and and of to"]) filters = ScoreFilter(CommonEnglishWordsFilter()) @@ -365,8 +568,10 @@ def test_commonenglishwords(self): expected_indices = [1, 2] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_wordswithoutalphabets(self): dataset = list_to_dataset(["totally fine", "good good good good !", "@"]) filters = ScoreFilter(WordsWithoutAlphabetsFilter()) @@ -374,16 +579,26 @@ def test_wordswithoutalphabets(self): expected_indices = [0, 1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + 
expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_pornographicurls(self): - dataset = list_to_dataset(["no url", "fine url https://www.nvidia.com/en-us/", "bad url https://www.pornhub.com/"]) + dataset = list_to_dataset( + [ + "no url", + "fine url https://www.nvidia.com/en-us/", + "bad url https://www.pornhub.com/", + ] + ) filters = ScoreFilter(PornographicUrlsFilter()) filtered_data = filters(dataset) expected_indices = [0, 1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" class TestCodeFilters: @@ -398,21 +613,25 @@ def test_python_comment_to_code(self): expected_indices = [0, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_general_commment_to_code(self): - doc_1 = "// Good code\nprintf(\"hello world\\n\")" - doc_2 = "printf(\"bad code\\n\")" + doc_1 = '// Good code\nprintf("hello world\\n")' + doc_2 = 'printf("bad code\\n")' doc_3 = "// Way far too many\n// comments!" - doc_4 = "/*\nGood comment\n*/\nprintf(\"hello world\\n\")" + doc_4 = '/*\nGood comment\n*/\nprintf("hello world\\n")' dataset = list_to_dataset([doc_1, doc_2, doc_3, doc_4]) filters = ScoreFilter(GeneralCommentToCodeFilter("text/x-c++")) filtered_data = filters(dataset) expected_indices = [0, 3] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_number_lines_code(self): doc_1 = """print("too short")""" doc_2 = """print("just") @@ -427,17 +646,23 @@ def test_number_lines_code(self): expected_indices = [1] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" def test_xml_header(self): - dataset = list_to_dataset(["no header", "", "slightly offset ", "slightly offset ?$#@!", "mixed <>"]) filters = ScoreFilter(AlphaFilter()) @@ -445,8 +670,10 @@ def test_alpha(self): expected_indices = [0, 2] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_html_boilerplate(self): good_doc = """ @@ -489,8 +716,10 @@ def test_html_boilerplate(self): expected_indices = [0] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" - + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" + def test_per_extension_filter(self): good_cpp = """ #include @@ -503,10 +732,22 @@ def test_per_extension_filter(self): }; """ dataset = list_to_dataset([good_cpp]) - metadata_file = 
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "nemo_curator", "utils", "code_meta.csv")) - filters = ScoreFilter(PerExtensionFilter("c++", "cpp", metadata_file=metadata_file)) + metadata_file = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + "..", + "nemo_curator", + "utils", + "code_meta.csv", + ) + ) + filters = ScoreFilter( + PerExtensionFilter("c++", "cpp", metadata_file=metadata_file) + ) filtered_data = filters(dataset) expected_indices = [0] expected_data = DocumentDataset(dataset.df.loc[expected_indices]) - assert all_equal(expected_data, filtered_data), f"Expected {expected_data} but got {filtered_data}" \ No newline at end of file + assert all_equal( + expected_data, filtered_data + ), f"Expected {expected_data} but got {filtered_data}" diff --git a/tests/test_pii_accuracy.py b/tests/test_pii_accuracy.py index 7b3d836e..9431779a 100644 --- a/tests/test_pii_accuracy.py +++ b/tests/test_pii_accuracy.py @@ -32,7 +32,10 @@ def load_test_cases(filename): (re.sub(r"<[^>]*>([^<]*)</[^>]*>", r"\1", line)).strip() for line in data ] masked_data = [ - (re.sub(r"(<[^>]*>([^<]*)</[^>]*>)", lambda x: "*" * len(x.group(2)), line)).strip() for line in data + ( + re.sub(r"(<[^>]*>([^<]*)</[^>]*>)", lambda x: "*" * len(x.group(2)), line) + ).strip() + for line in data ] return list(zip(raw_data, masked_data)) @@ -70,7 +73,7 @@ def test_address(self): def test_ssn(self): generate_single_category_test("US_SSN", "ssn.txt") - + def test_birthdates(self): generate_single_category_test("DATE_TIME", "birthdates.txt") @@ -84,7 +87,7 @@ def test_phone_numbers(self): generate_single_category_test("PHONE_NUMBER", "phone_numbers.txt") def test_multiple(self): - deidentifier = PiiDeidentifier("en", anonymize_action='mask') + deidentifier = PiiDeidentifier("en", anonymize_action="mask") test_data = load_test_cases("multiple.txt") for input, target in test_data: @@ -103,7 +106,7 @@ def test_multiple(self): assert match == True def test_batch_accuracy(self): - deidentifier = PiiDeidentifier("en", anonymize_action='mask') + deidentifier = PiiDeidentifier("en", anonymize_action="mask") test_data = load_test_cases("multiple.txt") inputs = [data[0] for data in test_data] targets = [data[1] for data in test_data] @@ -114,4 +117,4 @@ def test_batch_accuracy(self): match = all(compare_outputs(x, y) for x, y in zip(outputs, targets)) print("Matches:", "No" if not match else "Yes") - assert match == True \ No newline at end of file + assert match == True diff --git a/tests/test_task_decontamination.py b/tests/test_task_decontamination.py index 2c4f4b7f..0341ba7d 100644 --- a/tests/test_task_decontamination.py +++ b/tests/test_task_decontamination.py @@ -12,26 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import pandas as pd -import dask.dataframe as dd from collections import defaultdict +import dask.dataframe as dd +import pandas as pd +import pytest + import nemo_curator from nemo_curator.datasets import DocumentDataset from nemo_curator.tasks import DownstreamTask + class SimpleTask(DownstreamTask): def __init__(self, min_ngram_size=8, max_ngram_size=13): super().__init__() - self._task_name = 'simple' + self._task_name = "simple" self._min_ngram_size = min_ngram_size self._max_ngram_size = max_ngram_size - self._dataset = ["This is a simple example task document with enough words.", - "Two plus two equals four. Two times two equals four as well. However, two divided by two is one.", - "Cars are (sometimes) red. 
Bicycles can (occasionally) be blue. Trees often have green leaves.", - "This is a short one.", - "This is a short one."] + self._dataset = [ + "This is a simple example task document with enough words.", + "Two plus two equals four. Two times two equals four as well. However, two divided by two is one.", + "Cars are (sometimes) red. Bicycles can (occasionally) be blue. Trees often have green leaves.", + "This is a short one.", + "This is a short one.", + ] def generate_ngrams(self): for line in self._dataset: @@ -39,10 +43,11 @@ def generate_ngrams(self): return self.ngrams + class TinyTask(DownstreamTask): def __init__(self, min_ngram_size=0, max_ngram_size=4): super().__init__() - self._task_name = 'tiny' + self._task_name = "tiny" self._min_ngram_size = min_ngram_size self._max_ngram_size = max_ngram_size self._dataset = ["Super tiny task with one document."] @@ -53,166 +58,250 @@ def generate_ngrams(self): return self.ngrams + def list_to_dataset(documents, col_name="text", npartitions=2): data = {col_name: documents} pdf = pd.DataFrame(data) - + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + @pytest.fixture def contaminated_dataset(): - return list_to_dataset(["This document is fine", - "Before contamination. This is a simple example task document with enough words. After contamination.", - "This document is not good. Two plus two equals four. Two times two equals four as well. However, one minus one is zero.", - "Long contamination. Birds are (sometimes) red. Bicycles can (occasionally) be blue. Trees often have green leaves. After contamination.", - "Small contamination in a very short document 1. SupeR tiNy task with one document!", - "Small contamination in a very short document 2. Super tiNy task with one document?"], - col_name="text") + return list_to_dataset( + [ + "This document is fine", + "Before contamination. This is a simple example task document with enough words. After contamination.", + "This document is not good. Two plus two equals four. Two times two equals four as well. However, one minus one is zero.", + "Long contamination. Birds are (sometimes) red. Bicycles can (occasionally) be blue. Trees often have green leaves. After contamination.", + "Small contamination in a very short document 1. SupeR tiNy task with one document!", + "Small contamination in a very short document 2. 
Super tiNy task with one document?", + ], + col_name="text", + ) + class TestPrepareTaskData: def test_single_task(self): decontaminator = nemo_curator.TaskDecontamination(SimpleTask()) actual_count = decontaminator.prepare_task_ngram_count().compute() - expected_ngrams = ["this is a simple example task document with enough words", - "two plus two equals four two times two equals four as well however", - "plus two equals four two times two equals four as well however two", - "two equals four two times two equals four as well however two divided", - "equals four two times two equals four as well however two divided by", - "four two times two equals four as well however two divided by two", - "two times two equals four as well however two divided by two is", - "times two equals four as well however two divided by two is one", - "cars are sometimes red bicycles can occasionally be blue trees often have green", - "are sometimes red bicycles can occasionally be blue trees often have green leaves", - ] + expected_ngrams = [ + "this is a simple example task document with enough words", + "two plus two equals four two times two equals four as well however", + "plus two equals four two times two equals four as well however two", + "two equals four two times two equals four as well however two divided", + "equals four two times two equals four as well however two divided by", + "four two times two equals four as well however two divided by two", + "two times two equals four as well however two divided by two is", + "times two equals four as well however two divided by two is one", + "cars are sometimes red bicycles can occasionally be blue trees often have green", + "are sometimes red bicycles can occasionally be blue trees often have green leaves", + ] expected_count = {ngram: 0 for ngram in expected_ngrams} - assert len(expected_count) == len(actual_count), f"Expected #{len(expected_count)}, got #{len(actual_count)}" - assert expected_count == actual_count, f"Expected: {expected_count}, got: {actual_count}" - + assert len(expected_count) == len( + actual_count + ), f"Expected #{len(expected_count)}, got #{len(actual_count)}" + assert ( + expected_count == actual_count + ), f"Expected: {expected_count}, got: {actual_count}" + def test_multiple_tasks(self): decontaminator = nemo_curator.TaskDecontamination([SimpleTask(), TinyTask()]) actual_count = decontaminator.prepare_task_ngram_count().compute() - expected_ngrams = ["this is a simple example task document with enough words", - "two plus two equals four two times two equals four as well however", - "plus two equals four two times two equals four as well however two", - "two equals four two times two equals four as well however two divided", - "equals four two times two equals four as well however two divided by", - "four two times two equals four as well however two divided by two", - "two times two equals four as well however two divided by two is", - "times two equals four as well however two divided by two is one", - "cars are sometimes red bicycles can occasionally be blue trees often have green", - "are sometimes red bicycles can occasionally be blue trees often have green leaves", - "super tiny task with", - "tiny task with one", - "task with one document", - ] + expected_ngrams = [ + "this is a simple example task document with enough words", + "two plus two equals four two times two equals four as well however", + "plus two equals four two times two equals four as well however two", + "two equals four two times two equals four as well 
however two divided", + "equals four two times two equals four as well however two divided by", + "four two times two equals four as well however two divided by two", + "two times two equals four as well however two divided by two is", + "times two equals four as well however two divided by two is one", + "cars are sometimes red bicycles can occasionally be blue trees often have green", + "are sometimes red bicycles can occasionally be blue trees often have green leaves", + "super tiny task with", + "tiny task with one", + "task with one document", + ] expected_count = {ngram: 0 for ngram in expected_ngrams} - assert len(expected_count) == len(actual_count), f"Expected #{len(expected_count)}, got #{len(actual_count)}" - assert expected_count == actual_count, f"Expected: {expected_count}, got: {actual_count}" + assert len(expected_count) == len( + actual_count + ), f"Expected #{len(expected_count)}, got #{len(actual_count)}" + assert ( + expected_count == actual_count + ), f"Expected: {expected_count}, got: {actual_count}" + class TestFindMatchingNgrams: def test_single_task(self, contaminated_dataset): decontaminator = nemo_curator.TaskDecontamination(SimpleTask()) task_ngrams = decontaminator.prepare_task_ngram_count() - actual_result = decontaminator.find_matching_ngrams(task_ngrams, contaminated_dataset).compute() - actual_ngrams, actual_freq = actual_result['matched-ngrams'], actual_result['ngrams-freq'] + actual_result = decontaminator.find_matching_ngrams( + task_ngrams, contaminated_dataset + ).compute() + actual_ngrams, actual_freq = ( + actual_result["matched-ngrams"], + actual_result["ngrams-freq"], + ) expected_ngrams = defaultdict(int) expected_ngrams["this is a simple example task document with enough words"] = 1 - expected_ngrams["two plus two equals four two times two equals four as well however"] = 1 - expected_ngrams["are sometimes red bicycles can occasionally be blue trees often have green leaves"] = 1 + expected_ngrams[ + "two plus two equals four two times two equals four as well however" + ] = 1 + expected_ngrams[ + "are sometimes red bicycles can occasionally be blue trees often have green leaves" + ] = 1 expected_freq = [(10, 1), (13, 9)] - assert expected_freq == actual_freq, f"Expected #{expected_freq}, got #{actual_freq}" - assert expected_ngrams == actual_ngrams, f"Expected: {expected_ngrams}, got: {actual_ngrams}" - + assert ( + expected_freq == actual_freq + ), f"Expected #{expected_freq}, got #{actual_freq}" + assert ( + expected_ngrams == actual_ngrams + ), f"Expected: {expected_ngrams}, got: {actual_ngrams}" + def test_multiple_tasks(self, contaminated_dataset): decontaminator = nemo_curator.TaskDecontamination([SimpleTask(), TinyTask()]) task_ngrams = decontaminator.prepare_task_ngram_count() - actual_result = decontaminator.find_matching_ngrams(task_ngrams, contaminated_dataset).compute() - actual_ngrams, actual_freq = actual_result['matched-ngrams'], actual_result['ngrams-freq'] + actual_result = decontaminator.find_matching_ngrams( + task_ngrams, contaminated_dataset + ).compute() + actual_ngrams, actual_freq = ( + actual_result["matched-ngrams"], + actual_result["ngrams-freq"], + ) expected_ngrams = defaultdict(int) expected_ngrams["this is a simple example task document with enough words"] = 1 - expected_ngrams["two plus two equals four two times two equals four as well however"] = 1 - expected_ngrams["are sometimes red bicycles can occasionally be blue trees often have green leaves"] = 1 + expected_ngrams[ + "two plus two equals four two times two 
equals four as well however" + ] = 1 + expected_ngrams[ + "are sometimes red bicycles can occasionally be blue trees often have green leaves" + ] = 1 expected_ngrams["super tiny task with"] = 2 expected_freq = [(4, 3), (10, 1), (13, 9)] - assert expected_freq == actual_freq, f"Expected #{expected_freq}, got #{actual_freq}" - assert expected_ngrams == actual_ngrams, f"Expected: {expected_ngrams}, got: {actual_ngrams}" + assert ( + expected_freq == actual_freq + ), f"Expected #{expected_freq}, got #{actual_freq}" + assert ( + expected_ngrams == actual_ngrams + ), f"Expected: {expected_ngrams}, got: {actual_ngrams}" + class TestRemoveMatchingNgrams: def test_single_task(self, contaminated_dataset): - decontaminator = nemo_curator.TaskDecontamination(SimpleTask(), min_document_length=1, remove_char_each_side=1) + decontaminator = nemo_curator.TaskDecontamination( + SimpleTask(), min_document_length=1, remove_char_each_side=1 + ) task_ngrams = decontaminator.prepare_task_ngram_count() - ngram_data = decontaminator.find_matching_ngrams(task_ngrams, contaminated_dataset) - matched_ngrams, ngram_freq = ngram_data['matched-ngrams'], ngram_data['ngrams-freq'] - filtered_dataset = decontaminator.remove_matching_ngrams(matched_ngrams, ngram_freq, contaminated_dataset) - - actual_data = sorted(filtered_dataset.df.compute()['text'].to_list()) - expected_data = ["This document is fine", - "Before contamination.", - " After contamination.", - "This document is not good.", - "Long contamination.", - " After contamination.", - "Small contamination in a very short document 1. SupeR tiNy task with one document!", - "Small contamination in a very short document 2. Super tiNy task with one document?"] + ngram_data = decontaminator.find_matching_ngrams( + task_ngrams, contaminated_dataset + ) + matched_ngrams, ngram_freq = ( + ngram_data["matched-ngrams"], + ngram_data["ngrams-freq"], + ) + filtered_dataset = decontaminator.remove_matching_ngrams( + matched_ngrams, ngram_freq, contaminated_dataset + ) + + actual_data = sorted(filtered_dataset.df.compute()["text"].to_list()) + expected_data = [ + "This document is fine", + "Before contamination.", + " After contamination.", + "This document is not good.", + "Long contamination.", + " After contamination.", + "Small contamination in a very short document 1. SupeR tiNy task with one document!", + "Small contamination in a very short document 2. 
Super tiNy task with one document?", + ] expected_data.sort() - assert expected_data == actual_data, f"Expected #{expected_data}, got #{actual_data}" - + assert ( + expected_data == actual_data + ), f"Expected #{expected_data}, got #{actual_data}" + def test_multiple_tasks(self, contaminated_dataset): - decontaminator = nemo_curator.TaskDecontamination([SimpleTask(), TinyTask()], min_document_length=1, remove_char_each_side=1) + decontaminator = nemo_curator.TaskDecontamination( + [SimpleTask(), TinyTask()], min_document_length=1, remove_char_each_side=1 + ) task_ngrams = decontaminator.prepare_task_ngram_count() - ngram_data = decontaminator.find_matching_ngrams(task_ngrams, contaminated_dataset) - matched_ngrams, ngram_freq = ngram_data['matched-ngrams'], ngram_data['ngrams-freq'] - filtered_dataset = decontaminator.remove_matching_ngrams(matched_ngrams, ngram_freq, contaminated_dataset) - - actual_data = sorted(filtered_dataset.df.compute()['text'].to_list()) - expected_data = ["This document is fine", - "Before contamination.", - " After contamination.", - "This document is not good.", - "Long contamination.", - " After contamination.", - "Small contamination in a very short document 1.", - "Small contamination in a very short document 2."] + ngram_data = decontaminator.find_matching_ngrams( + task_ngrams, contaminated_dataset + ) + matched_ngrams, ngram_freq = ( + ngram_data["matched-ngrams"], + ngram_data["ngrams-freq"], + ) + filtered_dataset = decontaminator.remove_matching_ngrams( + matched_ngrams, ngram_freq, contaminated_dataset + ) + + actual_data = sorted(filtered_dataset.df.compute()["text"].to_list()) + expected_data = [ + "This document is fine", + "Before contamination.", + " After contamination.", + "This document is not good.", + "Long contamination.", + " After contamination.", + "Small contamination in a very short document 1.", + "Small contamination in a very short document 2.", + ] expected_data.sort() - assert expected_data == actual_data, f"Expected #{expected_data}, got #{actual_data}" + assert ( + expected_data == actual_data + ), f"Expected #{expected_data}, got #{actual_data}" + class TestFullPipeline: def test_single_task(self, contaminated_dataset): - decontaminator = nemo_curator.TaskDecontamination(SimpleTask(), min_document_length=1, remove_char_each_side=1) + decontaminator = nemo_curator.TaskDecontamination( + SimpleTask(), min_document_length=1, remove_char_each_side=1 + ) filtered_dataset = decontaminator(contaminated_dataset) - actual_data = sorted(filtered_dataset.df.compute()['text'].to_list()) - expected_data = ["This document is fine", - "Before contamination.", - " After contamination.", - "This document is not good.", - "Long contamination.", - " After contamination.", - "Small contamination in a very short document 1. SupeR tiNy task with one document!", - "Small contamination in a very short document 2. Super tiNy task with one document?"] + actual_data = sorted(filtered_dataset.df.compute()["text"].to_list()) + expected_data = [ + "This document is fine", + "Before contamination.", + " After contamination.", + "This document is not good.", + "Long contamination.", + " After contamination.", + "Small contamination in a very short document 1. SupeR tiNy task with one document!", + "Small contamination in a very short document 2. 
Super tiNy task with one document?", + ] expected_data.sort() - assert expected_data == actual_data, f"Expected #{expected_data}, got #{actual_data}" - + assert ( + expected_data == actual_data + ), f"Expected #{expected_data}, got #{actual_data}" + def test_multiple_tasks(self, contaminated_dataset): - decontaminator = nemo_curator.TaskDecontamination([SimpleTask(), TinyTask()], min_document_length=1, remove_char_each_side=1) + decontaminator = nemo_curator.TaskDecontamination( + [SimpleTask(), TinyTask()], min_document_length=1, remove_char_each_side=1 + ) filtered_dataset = decontaminator(contaminated_dataset) - actual_data = sorted(filtered_dataset.df.compute()['text'].to_list()) - expected_data = ["This document is fine", - "Before contamination.", - " After contamination.", - "This document is not good.", - "Long contamination.", - " After contamination.", - "Small contamination in a very short document 1.", - "Small contamination in a very short document 2."] + actual_data = sorted(filtered_dataset.df.compute()["text"].to_list()) + expected_data = [ + "This document is fine", + "Before contamination.", + " After contamination.", + "This document is not good.", + "Long contamination.", + " After contamination.", + "Small contamination in a very short document 1.", + "Small contamination in a very short document 2.", + ] expected_data.sort() - assert expected_data == actual_data, f"Expected #{expected_data}, got #{actual_data}" \ No newline at end of file + assert ( + expected_data == actual_data + ), f"Expected #{expected_data}, got #{actual_data}" diff --git a/tests/test_unicode_reformatter.py b/tests/test_unicode_reformatter.py index 56f022e4..01ac716b 100644 --- a/tests/test_unicode_reformatter.py +++ b/tests/test_unicode_reformatter.py @@ -19,29 +19,33 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers import UnicodeReformatter + def list_to_dataset(documents, col_name="text", npartitions=2): data = {col_name: documents} pdf = pd.DataFrame(data) - + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + class TestUnicodeReformatter: def test_reformatting(self): # Examples taken from ftfy documentation: # https://ftfy.readthedocs.io/en/latest/ - dataset = list_to_dataset([ - "✔ No problems", - "The Mona Lisa doesn’t have eyebrows.", - "l’humanité", - "à perturber la réflexion", - "Clean document already." - ]) + dataset = list_to_dataset( + [ + "✔ No problems", + "The Mona Lisa doesn’t have eyebrows.", + "l’humanité", + "à perturber la réflexion", + "Clean document already.", + ] + ) expected_results = [ "✔ No problems", "The Mona Lisa doesn't have eyebrows.", "l'humanité", "à perturber la réflexion", - "Clean document already." 
+ "Clean document already.", ] expected_results.sort() @@ -50,4 +54,6 @@ def test_reformatting(self): actual_results = fixed_dataset.df.compute()["text"].to_list() actual_results.sort() - assert expected_results == actual_results, f"Expected: {expected_results}, but got: {actual_results}" \ No newline at end of file + assert ( + expected_results == actual_results + ), f"Expected: {expected_results}, but got: {actual_results}" diff --git a/tutorials/tinystories/helpers.py b/tutorials/tinystories/helpers.py index ac486a56..45215dcc 100644 --- a/tutorials/tinystories/helpers.py +++ b/tutorials/tinystories/helpers.py @@ -15,10 +15,7 @@ import json import os -from docbuilder import ( - TinyStoriesExtractor, - TinyStoriesIterator, -) +from docbuilder import TinyStoriesExtractor, TinyStoriesIterator def write_jsonl(input_filename: str, output_dir: str, dump_every_n: int = 10000): diff --git a/tutorials/tinystories/main.py b/tutorials/tinystories/main.py index cd0c1096..13df0712 100644 --- a/tutorials/tinystories/main.py +++ b/tutorials/tinystories/main.py @@ -24,10 +24,7 @@ from nemo_curator import ScoreFilter, Sequential from nemo_curator.datasets import DocumentDataset -from nemo_curator.filters import ( - RepeatingTopNGramsFilter, - WordCountFilter, -) +from nemo_curator.filters import RepeatingTopNGramsFilter, WordCountFilter from nemo_curator.modifiers.pii_modifier import PiiModifierBatched from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter from nemo_curator.modules import ExactDuplicates