Skip to content

Commit

Permalink
[WIP] addressed comments in #193 apart from resolving .iloc pattern, …
Browse files Browse the repository at this point in the history
…test currently failing

Signed-off-by: Shuoyang Ding <shuoyangd@nvidia.com>
  • Loading branch information
shuoyangd committed Sep 20, 2024
1 parent 8a367dd commit 396d7ba
Show file tree
Hide file tree
Showing 15 changed files with 416 additions and 274 deletions.
5 changes: 3 additions & 2 deletions nemo_curator/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .doc_dataset import DocumentDataset, ParallelDataset
from .doc_dataset import DocumentDataset
from .parallel_dataset import ParallelDataset

__all__ = ["DocumentDataset"]
__all__ = ["DocumentDataset", "ParallelDataset"]
70 changes: 1 addition & 69 deletions nemo_curator/datasets/doc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@

import dask.dataframe as dd

from nemo_curator.utils.distributed_utils import (
read_data,
read_simple_bitext_data,
write_to_disk,
)
from nemo_curator.utils.distributed_utils import read_data, write_to_disk
from nemo_curator.utils.file_utils import get_all_files_paths_under


Expand Down Expand Up @@ -256,67 +252,3 @@ def _read_json_or_parquet(
raise TypeError("File input must be a string or list.")

return raw_data


class ParallelDataset(DocumentDataset):
"""
An extension of the standard `DocumentDataset` with a special method that loads simple bitext.
For data with more complicated metadata, please convert your data into jsonl/parquet/pickle format
and use interfaces defined in `DocumentDataset`.
"""

def persist(self):
return ParallelDataset(self.df.persist())

@classmethod
def read_simple_bitext(
cls,
src_input_files: Union[str, List[str]],
tgt_input_files: Union[str, List[str]],
src_lang: str,
tgt_lang: str,
backend: str = "pandas",
add_filename: bool = False,
partition_size: Optional[Union[int, str]] = "100MB",
):
if isinstance(src_input_files, list) and isinstance(tgt_input_files, list):
df = read_simple_bitext_data(
src_input_files,
tgt_input_files,
src_lang,
tgt_lang,
backend,
add_filename,
)
elif isinstance(src_input_files, str) and isinstance(tgt_input_files, str):
df = read_simple_bitext_data(
[src_input_files],
[tgt_input_files],
src_lang,
tgt_lang,
backend,
add_filename,
)
else:
raise TypeError("Both file inputs must be strings or lists.")

# if partition_size:
# df = df.repartition(partition_size=partition_size)
return cls(df)

def to_bitext(
self,
output_file_dir,
write_to_filename=False,
):
"""
See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
"""
write_to_disk(
df=self.df,
output_file_dir=output_file_dir,
write_to_filename=write_to_filename,
output_type="bitext",
)
157 changes: 157 additions & 0 deletions nemo_curator/datasets/parallel_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import csv
from typing import List, Optional, Tuple, Union

import dask.dataframe as dd
import pandas as pd

from nemo_curator.datasets.doc_dataset import DocumentDataset
from nemo_curator.utils.distributed_utils import write_to_disk
from nemo_curator.utils.file_utils import remove_path_extension
from nemo_curator.utils.import_utils import gpu_only_import

cudf = gpu_only_import("cudf")
dask_cudf = gpu_only_import("dask_cudf")


class ParallelDataset(DocumentDataset):
"""
An extension of the standard `DocumentDataset` with a special method that loads simple bitext.
For data with more complicated metadata, please convert your data into jsonl/parquet/pickle format
and use interfaces defined in `DocumentDataset`.
"""

def persist(self):
return ParallelDataset(self.df.persist())

@classmethod
def read_simple_bitext(
cls,
src_input_files: Union[str, List[str]],
tgt_input_files: Union[str, List[str]],
src_lang: str,
tgt_lang: str,
backend: str = "pandas",
add_filename: bool = False,
partition_size: Optional[Union[int, str]] = "100MB",
):
"""See `read_single_simple_bitext_file_pair` docstring for what "simple_bitext" means and usage of other parameters.
Args:
src_input_files (Union[str, List[str]]): one or several input files, in source language
tgt_input_files (Union[str, List[str]]): one or several input files, in target language
Raises:
TypeError: If types of `src_input_files` and `tgt_input_files` doesn't agree.
Returns:
ParallelDataset: A `ParallelDataset` object with `self.df` holding the ingested simple bitext.
"""

if isinstance(src_input_files, str) and isinstance(tgt_input_files, str):
src_input_files = [src_input_files]
tgt_input_files = [tgt_input_files]
elif not isinstance(src_input_files, list) or not isinstance(
tgt_input_files, list
):
raise TypeError("Both file inputs must be strings or lists.")

# TODO: use default doc id for now
# but it might be useful to allow customizing doc id by passing a prefix
df = dd.from_map(
ParallelDataset.read_single_simple_bitext_file_pair,
list(zip(src_input_files, tgt_input_files)),
src_lang=src_lang,
tgt_lang=tgt_lang,
backend=backend,
add_filename=add_filename,
)

# if partition_size:
# df = df.repartition(partition_size=partition_size)
return cls(df)

def to_bitext(
self,
output_file_dir,
write_to_filename=False,
):
"""See `nemo_curator.utils.distributed_utils.write_to_disk` docstring for parameter usage."""
write_to_disk(
df=self.df,
output_file_dir=output_file_dir,
write_to_filename=write_to_filename,
output_type="bitext",
)

@staticmethod
def read_single_simple_bitext_file_pair(
input_file_pair: Tuple[str],
src_lang: str,
tgt_lang: str,
doc_id: str = None,
backend: str = "cudf",
add_filename: bool = False,
) -> Union[dd.DataFrame, dask_cudf.DataFrame]:
"""This function reads a pair of "simple bitext" files into a pandas DataFrame.
A simple bitext is a commonly data format in machine translation.
It consists of two plain text files with the same number of lines, each line pair being translations of each other. For example:
data.de:
```
Wir besitzen keine Reisetaschen aus Leder.
Die Firma produziert Computer für den deutschen Markt.
...
```
data.en:
```
We don't own duffel bags made of leather.
The company produces computers for the German market.
...
```
For simplicity, we also assume that the names of the two text files have the same prefix, except for different language code at the end as file extensions.
Args:
input_file_pair (Tuple[str]): A pair of file paths pointing to the input files
src_lang (str): Source language, in ISO-639-1 (two character) format (e.g. 'en')
tgt_lang (str): Target language, in ISO-639-1 (two character) format (e.g. 'en')
doc_id (str, optional): A string document id to assign to every segment in the file. Defaults to None.
backend (str, optional): Backend of the data frame. Defaults to "cudf".
add_filename (bool, optional): Add filename as an extra field to every segment in the file. Defaults to False.
Returns:
Union[dd.DataFrame, dask_cudf.DataFrame]
"""
src_input_file, tgt_input_file = input_file_pair
assert remove_path_extension(src_input_file) == remove_path_extension(
tgt_input_file
), f"Assuming source and target filenames would have common prefix before language code, but got {src_input_file} and {tgt_input_file}."

if not doc_id:
doc_id = "▁".join([src_input_file, tgt_input_file])
df_combined["doc_id"] = doc_id

# TODO: it seems like cudf.read_table can only take one file max
# so maybe we shouldn't pass more than one
if backend == "cudf":
df = cudf
else:
df = pd

df_src = df.read_table(src_input_file, names=["src"], quoting=csv.QUOTE_NONE)
df_tgt = df.read_table(tgt_input_file, names=["tgt"], quoting=csv.QUOTE_NONE)
assert len(df_src) == len(
df_tgt
), f"We assume the source and target file would have the same number of lines, but got {len(df_src)} and {len(df_tgt)}."
df_combined = df.concat([df_src, df_tgt], axis=1)
df_combined["src_lang"] = src_lang
df_combined["tgt_lang"] = tgt_lang

if add_filename:
df_combined["filename"] = remove_path_extension(src_input_file)

return df_combined
2 changes: 2 additions & 0 deletions nemo_curator/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,6 @@
"HTMLBoilerplateFilter",
"PerExtensionFilter",
"LengthRatioFilter",
"HistogramFilter",
"QualityEstimationFilter",
]
45 changes: 43 additions & 2 deletions nemo_curator/filters/classifier_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from typing import List

import dask
import fasttext
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -105,6 +104,9 @@ def _load_model(self):


class QualityEstimationFilter(DocumentFilter):
"""(Bitext filter) Use a Quality Estimation (QE) model to score individual segments and filter based on estimated quality score.
(reference: https://arxiv.org/pdf/2311.05350)
"""

# a mapping from supported model names to their corresponding model class
SUPPORTED_MODELS = {
Expand All @@ -114,6 +116,15 @@ class QualityEstimationFilter(DocumentFilter):
}

def __init__(self, model_name, cutoff, mode="always_en_x", gpu=False):
"""Args:
model_name (_type_): Name of the model, as listed in the `SUPPORTED_MODELS` variable.
cutoff (_type_): A cut-off threshold for filtering. All segments with scores lower than this threshold will be tossed away.
mode (str, optional): See `_score_document_with_qe` for definition. Defaults to "always_en_x".
gpu (bool, optional): Whether to use GPU. Defaults to False.
Raises:
NotImplementedError: If a model name outside the supported model list is passed.
"""
if model_name in self.SUPPORTED_MODELS:
self._name = model_name
else:
Expand All @@ -129,6 +140,26 @@ def __init__(self, model_name, cutoff, mode="always_en_x", gpu=False):
def _score_document_with_qe(
self, model, df: pd.Series, mode="always_en_x"
) -> List[float]:
"""Arrange the documents according to the inference mode, call the model to estimate translation quality.
Args:
model (_type_): QE model object to be called.
df (pd.Series): Data frame that holds the translation data.
mode (str, optional): Currently three inference modes are supported:
- `simple`: Maintain the translation direction as specified in the data and
simply pass the corresponding fields to the quality estimation model.
- `always_en_x`: Always pass the English side as the source and non-English side as the target.
This is the strategy used by the referenced paper: https://arxiv.org/pdf/2311.05350.
- `bidi`: Estimate quality on both directions, then average the score. Potentially more accurate
when original translation direction is uncertain (note that "original" translation direction
might have been flipped while building the data), but also twice as expensive computationally.
Defaults to "always_en_x".
Returns:
List[float]: A list of float scores corresponding to the individual score of each documents.
"""

def _is_en_x(src_lang: str, tgt_lang: str):
return src_lang == "en" and tgt_lang != "en"
Expand Down Expand Up @@ -181,7 +212,16 @@ def _has_en(src_lang: str, tgt_lang: str):
raise NotImplementedError

@batched
def score_document(self, df: pd.Series):
def score_document(self, df: pd.Series) -> pd.Series:
"""Wrapper function that scores documents in a data frame. Most work is done in `_score_document_with_qe`.
Args:
df (pd.Series): Data frame that holds the translation data.
Returns:
pd.Series: A list of float scores corresponding to the individual score of each documents.
"""

model_attr = f"{self._name}_{self._model_path}"
try:
model = load_object_on_worker(
Expand All @@ -197,4 +237,5 @@ def score_document(self, df: pd.Series):
return pd.Series(scores, index=df.index)

def keep_document(self, score):
"""Decides whether a single document should be retained according to a threshold of estimated quality score."""
return score >= self._cutoff
Loading

0 comments on commit 396d7ba

Please sign in to comment.