Skip to content

Commit

Permalink
Fix 255 - Improve separate_by_metadata performance for jsonl files (N…
Browse files Browse the repository at this point in the history
…VIDIA#256)

* Improve performance in jsonl files

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Improve performance in jsonl files

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Shutdown Dask cluster at exit

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Remove unneeded persist() and wait() operations

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Display only Dask error messages or above

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Cancel any remaining futures

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Remove Dask warning message

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Rename new arguments

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

* Refactor separate_by_metadata

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>

---------

Signed-off-by: miguelusque <miguelusque@users.noreply.github.com>
Co-authored-by: miguelusque <miguelusque@users.noreply.github.com>
  • Loading branch information
miguelusque and miguelusque authored Oct 1, 2024
1 parent 802ae31 commit 63d87c3
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 43 deletions.
80 changes: 52 additions & 28 deletions nemo_curator/scripts/separate_by_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,53 @@

import argparse
import json
import logging
import shutil

from nemo_curator.utils.distributed_utils import get_client, read_data
from nemo_curator.utils.file_utils import (
expand_outdir_and_mkdir,
get_all_files_paths_under,
separate_by_metadata,
)
from dask.distributed.utils import silence_logging_cmgr

from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import separate_by_metadata
from nemo_curator.utils.script_utils import ArgumentHelper


def main(args):
client = get_client(**ArgumentHelper.parse_client_args(args))
print(f"Beginning metadata separation for {args.input_metadata_field}")

files = get_all_files_paths_under(args.input_data_dir)
input_data = read_data(
files, file_type=args.input_file_type, backend="pandas", add_filename=True
)
with silence_logging_cmgr(logging.ERROR):
# Initializes a Dask cluster.
client = get_client(**ArgumentHelper.parse_client_args(args))

# Separete corpus by metadata
metadata_distribution = separate_by_metadata(
input_data=args.input_data_dir,
output_dir=args.output_data_dir,
metadata_field=args.input_metadata_field,
remove_metadata=args.remove_metadata_field,
output_type=args.output_file_type,
input_type=args.input_file_type,
include_values=args.include_values,
exclude_values=args.exclude_values,
)

output_dir = expand_outdir_and_mkdir(args.output_data_dir)
# Save metadata distribution to disk
with open(args.output_metadata_distribution, "w") as fp:
json.dump(metadata_distribution.compute(), fp)

metadata_field = args.input_metadata_field
print(f"Beginning metadata separation for {metadata_field}")
metadata_distribution = separate_by_metadata(
input_data,
output_dir,
metadata_field,
remove_metadata=args.remove_metadata_field,
output_type=args.output_file_type,
).compute()
print(f"Finished metadata separation for {metadata_field}")
# Optionally, remove input directory
if args.remove_input_dir:
print(f"Removing all files in {args.input_data_dir}")
shutil.rmtree(args.input_data_dir)
print(f"Finished removing all files in {args.input_data_dir}")

with open(args.output_metadata_distribution, "w") as fp:
json.dump(metadata_distribution, fp)
# Cancel any remaining futures (if any)
client.cancel(metadata_distribution)

if args.remove_input_dir:
print(f"Removing all files in {args.input_data_dir}")
shutil.rmtree(args.input_data_dir)
print(f"Finished removing all files in {args.input_data_dir}")
# Shut down the cluster
client.shutdown()

# Close the client
client.close()


def attach_args(
Expand Down Expand Up @@ -103,6 +111,22 @@ def attach_args(
"is desired to be separated from the others",
)

exclusive_filters_group = parser.add_mutually_exclusive_group(required=False)
exclusive_filters_group.add_argument(
"--include-values",
nargs="+",
type=str,
help="A list of strings representing specific values to be selected or included. "
"If provided, only the items matching these values should be kept.",
)
exclusive_filters_group.add_argument(
"--exclude-values",
nargs="+",
type=str,
help="A list of strings representing specific values to be excluded or ignored. "
"If provided, any items matching these values should be skipped.",
)

return parser


Expand Down
141 changes: 127 additions & 14 deletions nemo_curator/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import pathlib
from functools import partial, reduce
from typing import List, Union

import dask.bag as db
import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask import delayed

from nemo_curator.utils.distributed_utils import single_partition_write_with_filename
from nemo_curator.utils.distributed_utils import (
read_data,
single_partition_write_with_filename,
)

NEMO_CURATOR_HOME = os.environ.get(
"NEMO_CURATOR_HOME", os.path.join(os.path.expanduser("~"), ".nemo_curator")
Expand Down Expand Up @@ -96,7 +101,7 @@ def get_remaining_files(
for entry in os.scandir(input_file_path)
if os.path.basename(entry.path) not in completed_files
]
# Gaurd against non extension files if present in the input directory
# Guard against non extension files if present in the input directory
input_files = [f for f in input_files if f.endswith(input_file_type)]
input_files.sort()

Expand Down Expand Up @@ -131,13 +136,26 @@ def get_batched_files(


def write_dataframe_by_meta(
df: pd.DataFrame, output_dir, metadata_field, remove_metadata, output_type
df: pd.DataFrame,
output_dir: str,
metadata_field: str,
remove_metadata: bool = False,
output_type: str = "jsonl",
include_values: List[str] = None,
exclude_values: List[str] = None,
):
counts = df[metadata_field].value_counts().to_dict()

# Apply include_values or value_exclesion_filter if provided
if include_values is not None and include_values:
counts = {k: v for k, v in counts.items() if k in include_values}
elif exclude_values is not None and exclude_values:
counts = {k: v for k, v in counts.items() if k not in exclude_values}

for meta_value in counts:
meta_output_dir = expand_outdir_and_mkdir(os.path.join(output_dir, meta_value))
meta_slice = df[df[metadata_field] == meta_value]

if remove_metadata:
meta_slice = meta_slice.drop(columns=[metadata_field])
single_partition_write_with_filename(
Expand All @@ -154,35 +172,130 @@ def merge_counts(first: dict, second: dict):
return first


def write_record(
input_dir: str,
file_name: str,
line: str,
field: str,
output_dir: str,
include_values: List[str] = None,
exclude_values: List[str] = None,
):
try:
# Parse the JSON-encoded string 'line' into a Python dictionary
line = json.loads(line)

# Select category value
category = line[field]

if (exclude_values and category in exclude_values) or (
include_values and category not in include_values
):
return None

# Obtain the relative path
rel_path, file_name = os.path.split(
os.path.relpath(file_name, start=os.path.abspath(input_dir))
)

output_dir = os.path.join(output_dir, category, rel_path)
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/{file_name}", "a") as f:
f.write(json.dumps(line) + "\n")

return category
except (KeyError, ValueError, json.JSONDecodeError):
return None


def separate_by_metadata(
df: dd.DataFrame,
output_dir,
metadata_field,
remove_metadata=False,
output_type="jsonl",
input_data: Union[dd.DataFrame, str],
output_dir: str,
metadata_field: str,
remove_metadata: bool = False,
output_type: str = "jsonl",
input_type: str = "jsonl",
include_values: List[str] = None,
exclude_values: List[str] = None,
) -> dict:
"""
Saves the dataframe to subfolders named after a metadata
Args:
df: The dataframe to write. Must have a filename column for the shard.
input_data: Either a DataFrame or a string representing the path to the input directory.
If a DataFrame is provided, it must have a 'filename' column for the shard.
output_dir: The base directory for which all metadata based subdirs will be created under
metadata_field: The metadata field to split on
remove_metadata: Whether to remove the metadata from the dataframe when saving it
output_type: File type the dataset will be written to. Supported file formats include 'jsonl' (default),
'pickle', or 'parquet'. (default: jsonl)
include_values: A list of strings representing specific values to be selected or included.
If provided, only the items matching these values should be kept.
exclude_values: A list of strings representing specific values to be excluded or ignored.
If provided, any items matching these values should be skipped.
Returns:
A delayed dictionary mapping each metadata to the count of entries with that metadata value.
"""
delayed_data = df.to_delayed()

if include_values is not None and exclude_values is not None:
print("Error: 'include_values' and 'exclude_values' are mutually exclusive.")

return

# Create output_dir if needed
if output_dir:
output_dir = expand_outdir_and_mkdir(output_dir)

if isinstance(input_data, str):
print(f"Reading {input_type} files from {input_data}", flush=True)

if input_type in ["json", "jsonl"] and output_type in ["json", "jsonl"]:
# Read JSONL files with streaming (line-by-line), and include file path
bag = db.read_text(
os.path.join(input_data, "**", f"*.{input_type}"),
include_path=True,
)

# Parse JSON lines and retain the file path
bag = bag.map(
lambda x: write_record(
input_dir=input_data,
file_name=x[1],
line=x[0],
field=metadata_field,
output_dir=output_dir,
include_values=include_values,
exclude_values=exclude_values,
)
)

frequencies = dict(bag.frequencies().compute())
frequencies.pop(None, None) # Remove None when applying filters

return delayed(reduce)(merge_counts, [frequencies])
else:
input_data = read_data(
get_all_files_paths_under(input_data),
file_type=input_type,
backend="pandas",
add_filename=True,
)
delayed_counts = [
delayed(write_dataframe_by_meta)(
partition, output_dir, metadata_field, remove_metadata, output_type
partition,
output_dir,
metadata_field,
remove_metadata,
output_type,
include_values,
exclude_values,
)
for partition in delayed_data
for partition in input_data.to_delayed()
]
merged_counts = delayed(reduce)(merge_counts, delayed_counts)

return merged_counts
return delayed(reduce)(merge_counts, delayed_counts)


def parse_str_of_num_bytes(s, return_str=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_metadatasep(
add_filename=True,
).df
separate_by_metadata(
df=df,
input_data=df,
output_dir=str(output_dir),
metadata_field="metadata",
output_type=file_ext,
Expand Down

0 comments on commit 63d87c3

Please sign in to comment.