From f4606fdb7488a90cd8514548c0402d9b2780f32e Mon Sep 17 00:00:00 2001
From: Mehran Maghoumi
Date: Fri, 10 May 2024 10:25:40 -0700
Subject: [PATCH] [Tutorials] Add a tutorial for PEFT data curation (#45)

This PR adds a new tutorial to demonstrate data curation for PEFT use cases.

Signed-off-by: Mehran Maghoumi
Signed-off-by: Vibhu Jawa
---
 tutorials/peft-curation/README.md     |  19 +++
 tutorials/peft-curation/docbuilder.py | 118 +++++++++++++++++
 tutorials/peft-curation/filters.py    |  50 ++++++++
 tutorials/peft-curation/main.py       | 184 ++++++++++++++++++++++++++
 tutorials/peft-curation/modifiers.py  |  68 ++++++++++
 tutorials/tinystories/README.md       |   2 +-
 tutorials/tinystories/main.py         |   6 +-
 7 files changed, 445 insertions(+), 2 deletions(-)
 create mode 100644 tutorials/peft-curation/README.md
 create mode 100644 tutorials/peft-curation/docbuilder.py
 create mode 100644 tutorials/peft-curation/filters.py
 create mode 100644 tutorials/peft-curation/main.py
 create mode 100644 tutorials/peft-curation/modifiers.py

diff --git a/tutorials/peft-curation/README.md b/tutorials/peft-curation/README.md
new file mode 100644
index 000000000..afa0d66a3
--- /dev/null
+++ b/tutorials/peft-curation/README.md
@@ -0,0 +1,19 @@
+# Curating Datasets for Parameter-Efficient Fine-Tuning
+
+This tutorial demonstrates the usage of NeMo Curator's Python API to curate a dataset for
+parameter-efficient fine-tuning (PEFT).
+
+In this tutorial, we use the [Enron Emails dataset](https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning),
+a dataset of emails annotated with classification labels. Each email has a subject, a body,
+and a category (class label). We demonstrate various filtering and processing operations
+that can be applied to each record.
+
+## Usage
+After installing the NeMo Curator package, run the following command:
+```
+python tutorials/peft-curation/main.py
+```
+
+By default, this tutorial uses at most 8 workers to run the curation pipeline. If you run into
+out-of-memory issues, reduce the number of workers by supplying the `--n-workers=N` argument,
+where `N` is the number of workers to spawn.
diff --git a/tutorials/peft-curation/docbuilder.py b/tutorials/peft-curation/docbuilder.py
new file mode 100644
index 000000000..3ae0840c9
--- /dev/null
+++ b/tutorials/peft-curation/docbuilder.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
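+
+# This module implements the document builder for the Enron emails dataset:
+# a downloader that fetches the raw CSV, an iterator that splits it into
+# individual emails, and an extractor that pulls out each email's subject,
+# body, and category.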
+
+import os
+import re
+from typing import Dict, Optional
+
+import requests
+
+from nemo_curator.download.doc_builder import (
+    DocumentDownloader,
+    DocumentExtractor,
+    DocumentIterator,
+)
+
+
+class EmailsDownloader(DocumentDownloader):
+    def __init__(self, download_dir: str):
+        super().__init__()
+
+        if not os.path.isdir(download_dir):
+            os.makedirs(download_dir)
+
+        self._download_dir = download_dir
+        print("Download directory: ", self._download_dir)
+
+    def download(self, url: str) -> str:
+        filename = os.path.basename(url)
+        output_file = os.path.join(self._download_dir, filename)
+
+        if os.path.exists(output_file):
+            print(f"File '{output_file}' already exists, skipping download.")
+            return output_file
+
+        print(f"Downloading Enron emails dataset from '{url}'...")
+        response = requests.get(url)
+
+        with open(output_file, "wb") as file:
+            file.write(response.content)
+
+        return output_file
+
+
+class EmailsIterator(DocumentIterator):
+
+    def __init__(self):
+        super().__init__()
+        self._counter = -1
+        self._extractor = EmailsExtractor()
+        # The regular expression pattern to extract each email.
+        self._pattern = re.compile(r"\".*?\"", re.DOTALL)
+
+    def iterate(self, file_path):
+        self._counter = -1
+        file_name = os.path.basename(file_path)
+
+        with open(file_path, "r", encoding="utf-8") as file:
+            lines = file.readlines()
+
+        # Ignore the first line which contains the header.
+        file_content = "".join(lines[1:])
+        # Find all the emails in the file.
+        it = self._pattern.finditer(file_content)
+
+        for email in it:
+            self._counter += 1
+            content = email.group().strip('"').strip()
+            meta = {
+                "filename": file_name,
+                "id": f"email-{self._counter}",
+            }
+            extracted_content = self._extractor.extract(content)
+
+            # Skip if no content extracted
+            if not extracted_content:
+                continue
+
+            record = {**meta, **extracted_content}
+            yield record
+
+
+class EmailsExtractor(DocumentExtractor):
+    def __init__(self):
+        super().__init__()
+        # The regular expression pattern to extract subject/body/label into groups.
+        self._pattern = re.compile(
+            r"Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) ", re.DOTALL
+        )
+
+    def extract(self, content: str) -> Optional[Dict[str, str]]:
+        matches = self._pattern.findall(content)
+
+        if not matches:
+            return None
+
+        matches = matches[0]
+
+        return {
+            "subject": matches[0].strip(),
+            "body": matches[1].strip(),
+            "category": matches[2].strip(),
+        }
diff --git a/tutorials/peft-curation/filters.py b/tutorials/peft-curation/filters.py
new file mode 100644
index 000000000..0ffcd5be7
--- /dev/null
+++ b/tutorials/peft-curation/filters.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_curator.filters import DocumentFilter
+
+
+class FilterEmailsWithLongBody(DocumentFilter):
+    """
+    Discards emails whose bodies are longer than `max_length` characters.
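+
+    `score_document` returns True when the email body is within `max_length`
+    characters, and `keep_document` keeps only the documents scored True.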
+ """ + + def __init__(self, max_length: int = 5000): + super().__init__() + self.max_length = max_length + + def score_document(self, text: str) -> bool: + return len(text) <= self.max_length + + def keep_document(self, score) -> bool: + return score + + +class FilterEmptyEmails(DocumentFilter): + """ + Detects empty emails (either empty body, or labeled as empty). Returns `True` for empty emails. + """ + + def score_document(self, text: str) -> bool: + return ( + not isinstance(text, str) # The text is not a string + or len(text.strip()) == 0 # The text is empty + or "Empty message" in text # The email is labeled as empty + ) + + def keep_document(self, score) -> bool: + return score diff --git a/tutorials/peft-curation/main.py b/tutorials/peft-curation/main.py new file mode 100644 index 000000000..9210d9f89 --- /dev/null +++ b/tutorials/peft-curation/main.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +from functools import partial +from typing import Any + +from docbuilder import EmailsDownloader, EmailsIterator +from filters import FilterEmailsWithLongBody, FilterEmptyEmails +from modifiers import AddPeriod, AddSystemPrompt + +from nemo_curator import ScoreFilter, Sequential +from nemo_curator.datasets import DocumentDataset +from nemo_curator.modifiers.pii_modifier import PiiModifier +from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter +from nemo_curator.modules.modify import Modify +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import add_distributed_args + +SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__)) +DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data") +DATASET_URL = "https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning/raw/main/prompts_train.csv" + + +def download_and_convert_to_jsonl() -> str: + """ + Downloads the emails dataset and converts it to JSONL format. + + Returns: + str: The path to the JSONL file. + """ + + # Download the dataset in raw format and convert it to JSONL. + downloader = EmailsDownloader(DATA_DIR) + output_path = os.path.join(DATA_DIR, "emails.jsonl") + raw_fp = downloader.download(DATASET_URL) + + iterator = EmailsIterator() + + # Parse the raw data and write it to a JSONL file. + with open(output_path, "w") as f: + for record in iterator.iterate(raw_fp): + json_record = json.dumps(record, ensure_ascii=False) + f.write(json_record + "\n") + + return output_path + + +def redact_pii(dataset: DocumentDataset, text_field) -> DocumentDataset: + """ + Redacts personally identifiable information (PII) from a given dataset. + + Args: + dataset (DocumentDataset): The dataset containing documents with PII. + + Returns: + DocumentDataset: The redacted dataset with PII replaced by a generic value. 
+ """ + redactor = Modify( + PiiModifier( + supported_entities=[ + "ADDRESS", + "EMAIL_ADDRESS", + "LOCATION", + "PERSON", + "URL", + "PHONE_NUMBER", + ], + anonymize_action="replace", + device="cpu", + ), + text_field=text_field, + ) + return redactor(dataset) + + +def run_curation_pipeline(args: Any, jsonl_fp: str) -> str: + """ + Run the curation pipeline on the dataset. + + Args: + args (Any): Command-line arguments. + jsonl_fp (str): The path to the uncurated JSONL file. + + Returns: + str: The path to the curated JSONL file. + """ + client = get_client(args, args.device) + print(f" Running the curation pipeline on '{jsonl_fp}'...") + orig_dataset = DocumentDataset.read_json(jsonl_fp, add_filename=True) + dataset = orig_dataset + + redact_pii_subject = partial(redact_pii, text_field="subject") + redact_pii_body = partial(redact_pii, text_field="body") + + curation_steps = Sequential( + [ + # + # Unify the text encoding to Unicode. + # + Modify(UnicodeReformatter(), text_field="subject"), + Modify(UnicodeReformatter(), text_field="body"), + Modify(UnicodeReformatter(), text_field="category"), + # + # Filtering + # + # Filter out empty emails. + ScoreFilter( + FilterEmptyEmails(), text_field="subject", score_type=bool, invert=True + ), + ScoreFilter( + FilterEmptyEmails(), text_field="body", score_type=bool, invert=True + ), + ScoreFilter( + FilterEmptyEmails(), text_field="category", score_type=bool, invert=True + ), + # Filter out emails that are too long. + ScoreFilter(FilterEmailsWithLongBody(), text_field="body", score_type=bool), + # + # Redact personally identifiable information (PII). + # + redact_pii_subject, + redact_pii_body, + # + # Final modifications. + # + # Add system prompts to every email, which helps the model focus on the task. + Modify(AddSystemPrompt(), text_field="body"), + # Add a period to the end of each email category, which makes PEFT easier. + Modify(AddPeriod(), text_field="category"), + ] + ) + + dataset = curation_steps(dataset) + dataset = dataset.persist() + + print(f" Original dataset length: {len(orig_dataset.df)}") + print(f" After running the curation pipeline: {len(dataset.df)}") + print(f" Writing to '{jsonl_fp}'...") + out_path = os.path.join( + os.path.dirname(jsonl_fp), + "curated", + ) + os.makedirs(out_path, exist_ok=True) + dataset.to_json(out_path, write_to_filename=True) + client.close() + return os.path.join(out_path, os.path.basename(jsonl_fp)) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_distributed_args(parser) + args = parser.parse_args() + # Limit the total number of workers to ensure we don't run out of memory. + args.n_workers = min(args.n_workers, 8) + + # Prepare the download and JSONL directories. + if not os.path.isdir(DATA_DIR): + os.makedirs(DATA_DIR) + + jsonl_fp = download_and_convert_to_jsonl() + run_curation_pipeline(args, jsonl_fp) + + +if __name__ == "__main__": + main() diff --git a/tutorials/peft-curation/modifiers.py b/tutorials/peft-curation/modifiers.py new file mode 100644 index 000000000..059036ee4 --- /dev/null +++ b/tutorials/peft-curation/modifiers.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_curator.modifiers import DocumentModifier
+
+# The system prompt template to be inserted into the documents.
+SYS_PROMPT_TEMPLATE = """[INST] <<SYS>> You are reviewing the contents of an email. Based on the content, please categorize this email into one of the following categories:
+1. 'Company Business/Strategy.'
+2. 'Purely Personal.'
+3. 'Personal but in a professional context.'
+4. 'Logistic Arrangements.'
+5. 'Employment arrangements.'
+6. 'Document editing/checking/collaboration.'
+Please provide only one category (e.g., 'Purely Personal.'). <</SYS>>
+
+Content::
+%s
+
+What should this email be categorized as?
+[/INST]
+Answer:: """
+
+
+class AddSystemPrompt(DocumentModifier):
+    """
+    A simple modifier that adds a system prompt to each document.
+    """
+
+    def modify_document(self, text: str) -> str:
+        """
+        Inserts the system prompt into the document.
+
+        Args:
+            text (str): The text to be modified.
+
+        Returns:
+            str: The modified text.
+        """
+        return SYS_PROMPT_TEMPLATE % text
+
+
+class AddPeriod(DocumentModifier):
+    """
+    A simple modifier that adds a period to the end of each email category.
+    """
+
+    def modify_document(self, text: str) -> str:
+        """
+        Adds a period to the end of each email category.
+
+        Args:
+            text (str): The text to be modified.
+
+        Returns:
+            str: The modified text.
+        """
+        return text + "."
diff --git a/tutorials/tinystories/README.md b/tutorials/tinystories/README.md
index 47074cb3f..45bc3bf33 100644
--- a/tutorials/tinystories/README.md
+++ b/tutorials/tinystories/README.md
@@ -1,6 +1,6 @@
 # TinyStories
 
-This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are undersood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.
+This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are understood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.
 
 For simplicity, this tutorial uses the validation split of this dataset, which contains around 22,000 samples.
diff --git a/tutorials/tinystories/main.py b/tutorials/tinystories/main.py index fa4470c35..1fbbba35c 100644 --- a/tutorials/tinystories/main.py +++ b/tutorials/tinystories/main.py @@ -97,19 +97,23 @@ def filter_dataset(dataset: DocumentDataset) -> DocumentDataset: WordCountFilter(min_words=80), text_field="text", score_field="word_count", + score_type=int, ), - ScoreFilter(IncompleteStoryFilter(), text_field="text"), + ScoreFilter(IncompleteStoryFilter(), text_field="text", score_type=bool), ScoreFilter( RepeatingTopNGramsFilter(n=2, max_repeating_ngram_ratio=0.2), text_field="text", + score_type=float, ), ScoreFilter( RepeatingTopNGramsFilter(n=3, max_repeating_ngram_ratio=0.18), text_field="text", + score_type=float, ), ScoreFilter( RepeatingTopNGramsFilter(n=4, max_repeating_ngram_ratio=0.16), text_field="text", + score_type=float, ), ] )
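--
A note on the `score_type` arguments added in the hunk above: `ScoreFilter`
accepts the dtype of the score a filter produces, so declaring it up front
lets the score column's type be set explicitly rather than inferred. Below is
a minimal sketch of the resulting call shape, reusing only names already
imported by tutorials/tinystories/main.py:

    from nemo_curator import ScoreFilter
    from nemo_curator.filters import WordCountFilter

    # Keep documents with at least 80 words; word counts are integers,
    # so the score column's dtype is declared as int up front.
    score_filter = ScoreFilter(
        WordCountFilter(min_words=80),
        text_field="text",
        score_field="word_count",
        score_type=int,
    )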