Skip to content

Commit

Permalink
Fix metadata inference with pandas and dask (NVIDIA#35)
Browse files Browse the repository at this point in the history
* Fix metadata inference with pandas and dask

Signed-off-by: Ryan Wolf <rywolf@nvidia.com>

* Fix datatypes for task decontamination

Signed-off-by: Ryan Wolf <rywolf@nvidia.com>

* Use targeted import

Signed-off-by: Ryan Wolf <rywolf@nvidia.com>

---------

Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
Signed-off-by: Nicole Luo <nluo@nvidia.com>
  • Loading branch information
ryantwolf authored and nicoleeeluo committed May 20, 2024
1 parent ec26f9f commit 462a1a3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
9 changes: 8 additions & 1 deletion nemo_curator/modules/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
from dask.dataframe.extensions import make_array_nonempty
from dask.typing import no_default

from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.module_utils import is_batched

def _nonempty_string_array(dtype):
    """Return a two-element sample array of the given string dtype with no pd.NA."""
    return pd.array(["a", "b"], dtype=dtype)


# Override the registered handler for pd.StringDtype so that pd.NA is not
# passed during the metadata inference.
make_array_nonempty.register(pd.StringDtype, _nonempty_string_array)


class Score:
def __init__(self, score_fn, score_field, text_field="text", score_type=None):
Expand Down
12 changes: 11 additions & 1 deletion nemo_curator/modules/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ def _threshold_ngram_count(self, matched_ngrams: dict) -> set:
return filtered_ngrams

def _remove_ngrams_partition(self, partition, task_ngrams, ngrams_freq_sorted):
text_type = partition[self.text_field].dtype

document_fn = partial(
self._remove_ngrams,
task_ngrams=task_ngrams,
Expand All @@ -318,7 +320,15 @@ def _remove_ngrams_partition(self, partition, task_ngrams, ngrams_freq_sorted):

partition[self.text_field] = split_text
filtered_partition = partition[valid_documents_mask]
return filtered_partition.explode(self.text_field, ignore_index=True)
exploded_partition = filtered_partition.explode(
self.text_field, ignore_index=True
)
# After exploding, the string datatype can become an "object" type
exploded_partition[self.text_field] = exploded_partition[
self.text_field
].astype(text_type)

return exploded_partition

def _remove_ngrams(self, document, task_ngrams, ngrams_freq_sorted):
"""
Expand Down

0 comments on commit 462a1a3

Please sign in to comment.