[Tutorials] Add a tutorial for PEFT data curation
This PR adds a new tutorial that demonstrates data curation for PEFT use cases.

Signed-off-by: Mehran Maghoumi <Maghoumi@users.noreply.github.com>
Maghoumi committed Apr 30, 2024
1 parent f4355af commit c853919
Showing 5 changed files with 326 additions and 1 deletion.
15 changes: 15 additions & 0 deletions tutorials/peft-curation/README.md
@@ -0,0 +1,15 @@
# Curating Datasets for Parameter-Efficient Fine-Tuning

This tutorial demonstrates the usage of NeMo Curator's Python API to curate a dataset for
parameter-efficient fine-tuning (PEFT).

In this tutorial, we use the [Enron Emails dataset](https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning),
a collection of emails annotated with classification labels. Each email has a subject, a body, and a
category (class label). We demonstrate various filtering and processing operations that can be applied
to each record.
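
Each record produced by the curation pipeline keeps these fields alongside a bit of bookkeeping metadata. The example below is purely illustrative (placeholder values, not actual dataset contents):

```json
{
  "filename": "prompts_train.csv",
  "id": "email-0",
  "subject": "<illustrative subject line>",
  "body": "<illustrative email body>",
  "category": "<illustrative class label>"
}
```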

## Usage
After installing the NeMo Curator package, run the following command:
```bash
python tutorials/peft-curation/main.py
```
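
Once the script finishes, the curated dataset is written to a `curated/` directory next to the downloaded data (see `run_curation_pipeline` in `main.py`). A minimal sketch for inspecting the output, assuming the default paths, could look like:

```python
from nemo_curator.datasets import DocumentDataset

# The path below assumes the default layout produced by main.py.
curated = DocumentDataset.read_json(
    "tutorials/peft-curation/data/curated/emails.jsonl", add_filename=True
)
print(curated.df.head())
```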
111 changes: 111 additions & 0 deletions tutorials/peft-curation/docbuilder.py
@@ -0,0 +1,111 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import Dict, Optional

import requests

from nemo_curator.download.doc_builder import (
DocumentDownloader,
DocumentExtractor,
DocumentIterator,
)


class EmailsDownloader(DocumentDownloader):
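    """Downloads the raw dataset file from a given URL into a local directory."""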
def __init__(self, download_dir: str):
super().__init__()

if not os.path.isdir(download_dir):
os.makedirs(download_dir)

self._download_dir = download_dir
print("Download directory: ", self._download_dir)

def download(self, url: str) -> str:
filename = os.path.basename(url)
output_file = os.path.join(self._download_dir, filename)

if os.path.exists(output_file):
print(f"File '{output_file}' already exists, skipping download.")
return output_file

print(f"Downloading Enron emails dataset from '{url}'...")
response = requests.get(url)

with open(output_file, "wb") as file:
file.write(response.content)

return output_file


class EmailsIterator(DocumentIterator):
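    """Splits the raw dataset file into individual emails and yields one extracted record per email."""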

def __init__(self):
super().__init__()
self._counter = -1
self._extractor = EmailsExtractor()
# The regular expression pattern to extract each email.
self._pattern = re.compile(r'\"<s>.*?<s>\"', re.DOTALL)

def iterate(self, file_path):
self._counter = -1
file_name = os.path.basename(file_path)

with open(file_path, 'r', encoding='utf-8') as file:
lines = file.readlines()

# Ignore the first line which contains the header.
file_content = "".join(lines[1:])
# Find all the emails in the file.
it = self._pattern.finditer(file_content)

for email in it:
self._counter += 1
content = email.group().strip('"').strip()
meta = {
"filename": file_name,
"id": f"email-{self._counter}",
}
extracted_content = self._extractor.extract(content)

# Skip if no content extracted
if not extracted_content:
continue

record = {**meta, **extracted_content}
yield record


class EmailsExtractor(DocumentExtractor):
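    """Extracts the subject, body, and category (class label) fields from a single email."""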
def __init__(self):
super().__init__()
# The regular expression pattern to extract subject/body/label into groups.
self._pattern = re.compile(r'Subject:: (.*?)\nBody:: (.*?)\n.*\[/INST\] (.*?) <s>', re.DOTALL)

    def extract(self, content: str) -> Optional[Dict[str, str]]:
matches = self._pattern.findall(content)

if not matches:
return None

matches = matches[0]

return {
"subject": matches[0].strip(),
"body": matches[1].strip(),
"category": matches[2].strip(),
}
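
For orientation, here is a minimal, self-contained sketch of how the downloader and iterator above can be exercised on their own (assuming it is run from the `tutorials/peft-curation` directory; the URL and `data` directory mirror the defaults used in `main.py`):

```python
from docbuilder import EmailsDownloader, EmailsIterator

# Download the raw CSV file (skipped if it is already present on disk).
downloader = EmailsDownloader("data")
raw_fp = downloader.download(
    "https://huggingface.co/datasets/neelblabla/"
    "enron_labeled_emails_with_subjects-llama2-7b_finetuning/raw/main/prompts_train.csv"
)

# Iterate over the raw file and inspect the first extracted record.
iterator = EmailsIterator()
for record in iterator.iterate(raw_fp):
    print(record["id"], record["subject"])
    break
```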
41 changes: 41 additions & 0 deletions tutorials/peft-curation/filters.py
@@ -0,0 +1,41 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_curator.filters import DocumentFilter


class FilterEmailsWithLongBody(DocumentFilter):
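    """Keeps emails whose body is at most max_length characters long (5,000 by default)."""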
def __init__(self, max_length: int = 5000):
super().__init__()
self.max_length = max_length

def score_document(self, text: str) -> bool:
return isinstance(text, str) and len(text) <= self.max_length

def keep_document(self, score) -> bool:
return score


class FilterEmptyEmails(DocumentFilter):
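    """Keeps emails whose text is a non-empty string and not an "Empty message" placeholder."""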
    def score_document(self, text: str) -> bool:
        # Keep only values that are non-empty strings and are not the
        # dataset's "Empty message" placeholder for missing content.
        return (
            isinstance(text, str)
            and len(text.strip()) > 0
            and "Empty message" not in text
        )

def keep_document(self, score) -> bool:
return score
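
For reference, a minimal sketch of wiring these filters into a pipeline with `ScoreFilter` (mirroring the steps in `main.py`; `dataset` is assumed to be a `DocumentDataset` that has already been loaded):

```python
from nemo_curator import ScoreFilter, Sequential

from filters import FilterEmailsWithLongBody, FilterEmptyEmails

email_filters = Sequential(
    [
        # Drop records whose subject or body is empty.
        ScoreFilter(FilterEmptyEmails(), text_field="subject"),
        ScoreFilter(FilterEmptyEmails(), text_field="body"),
        # Drop records whose body exceeds the default 5,000-character limit.
        ScoreFilter(FilterEmailsWithLongBody(), text_field="body"),
    ]
)

# curated = email_filters(dataset)  # `dataset` is a DocumentDataset.
```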
158 changes: 158 additions & 0 deletions tutorials/peft-curation/main.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
from functools import partial
from typing import Any

from docbuilder import EmailsDownloader, EmailsIterator
from filters import FilterEmailsWithLongBody, FilterEmptyEmails

from nemo_curator import ScoreFilter, Sequential
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.pii_modifier import PiiModifier
from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter
from nemo_curator.modules.modify import Modify
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import add_distributed_args


SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data")
DATASET_URL = "https://huggingface.co/datasets/neelblabla/enron_labeled_emails_with_subjects-llama2-7b_finetuning/raw/main/prompts_train.csv"


def download_and_convert_to_jsonl() -> str:
"""
Downloads the emails dataset and converts it to JSONL format.
Returns:
str: The path to the JSONL file.
"""

# Download the dataset in raw format and convert it to JSONL.
downloader = EmailsDownloader(DATA_DIR)
output_path = os.path.join(DATA_DIR, "emails.jsonl")
raw_fp = downloader.download(DATASET_URL)

iterator = EmailsIterator()

# Parse the raw data and write it to a JSONL file.
with open(output_path, 'w') as f:
for record in iterator.iterate(raw_fp):
json_record = json.dumps(record, ensure_ascii=False)
f.write(json_record + "\n")

return output_path


def redact_pii(dataset: DocumentDataset, text_field: str) -> DocumentDataset:
    """
    Redacts personally identifiable information (PII) from a given dataset.
    Args:
        dataset (DocumentDataset): The dataset containing documents with PII.
        text_field (str): The name of the text field to redact.
    Returns:
        DocumentDataset: The redacted dataset with PII replaced by a generic value.
    """
redactor = Modify(
PiiModifier(
supported_entities=[
"ADDRESS",
"EMAIL_ADDRESS",
"LOCATION",
"PERSON",
"URL",
"PHONE_NUMBER",
],
anonymize_action="replace",
device="cpu",
),
text_field=text_field,
)
return redactor(dataset)


def run_curation_pipeline(args: Any, jsonl_fp: str) -> str:
"""
Run the curation pipeline on the dataset.
Args:
args (Any): Command-line arguments.
jsonl_fp (str): The path to the uncurated JSONL file.
Returns:
str: The path to the curated JSONL file.
"""
client = get_client(args, args.device)
print(f" Running the curation pipeline on '{jsonl_fp}'...")
orig_dataset = DocumentDataset.read_json(jsonl_fp, add_filename=True)
dataset = orig_dataset

    # The PII modifier operates on a single text field at a time, so apply it
    # separately to the subject and the body.
redact_pii_subject = partial(redact_pii, text_field="subject")
redact_pii_body = partial(redact_pii, text_field="body")

curation_steps = Sequential(
[
# Unify the text encoding to Unicode.
Modify(UnicodeReformatter(), text_field="subject"),
Modify(UnicodeReformatter(), text_field="body"),
Modify(UnicodeReformatter(), text_field="category"),
# Filter out empty emails.
ScoreFilter(FilterEmptyEmails(), text_field="subject"),
ScoreFilter(FilterEmptyEmails(), text_field="body"),
# Filter out emails that are too long.
ScoreFilter(FilterEmailsWithLongBody(), text_field="body"),
# Redact personally identifiable information (PII).
redact_pii_subject,
redact_pii_body,
]
)
dataset = curation_steps(dataset)
dataset = dataset.persist()

print(f" Original dataset length: {len(orig_dataset.df)}")
print(f" After running the curation pipeline: {len(dataset.df)}")
    out_path = os.path.join(
        os.path.dirname(jsonl_fp),
        "curated",
    )
    print(f" Writing the curated dataset to '{out_path}'...")
os.makedirs(out_path, exist_ok=True)
dataset.to_json(out_path, write_to_filename=True)
client.close()
return os.path.join(out_path, os.path.basename(jsonl_fp))


def main():
parser = argparse.ArgumentParser()
parser = add_distributed_args(parser)
args = parser.parse_args()
# Limit the total number of workers to ensure we don't run out of memory.
args.n_workers = min(args.n_workers, 1)

# Prepare the download and JSONL directories.
if not os.path.isdir(DATA_DIR):
os.makedirs(DATA_DIR)

jsonl_fp = download_and_convert_to_jsonl()
run_curation_pipeline(args, jsonl_fp)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion tutorials/tinystories/README.md
@@ -1,6 +1,6 @@
# TinyStories

This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are undersood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.
This tutorial demonstrates the usage of NeMo Curator's Python API to curate the [TinyStories](https://arxiv.org/abs/2305.07759) dataset. TinyStories is a dataset of short stories generated by GPT-3.5 and GPT-4, featuring words that are understood by 3 to 4-year olds. The small size of this dataset makes it ideal for creating and validating data curation pipelines on a local machine.

For simplicity, this tutorial uses the validation split of this dataset, which contains around 22,000 samples.

