Filter unsupported extensions (huggingface#5972)
* add filter_extensions

* test

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

* keep zip archives for imagefolder and audiofolder

* minor

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
lhoestq and albertvillanova committed Jun 22, 2023
1 parent 79c340f commit 76f75a9
Showing 5 changed files with 58 additions and 2 deletions.
19 changes: 19 additions & 0 deletions src/datasets/data_files.py
@@ -1,4 +1,5 @@
 import os
+import re
 from functools import partial
 from pathlib import Path, PurePath
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union
@@ -748,6 +749,18 @@ def from_local_or_remote(
         origin_metadata = _get_origin_metadata_locally_or_by_urls(data_files, use_auth_token=use_auth_token)
         return cls(data_files, origin_metadata)
 
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesList":
+        pattern = "|".join("\\" + ext for ext in extensions)
+        pattern = re.compile(f".*({pattern})(\\..+)?$")
+        return DataFilesList(
+            [
+                data_file
+                for data_file in self
+                if pattern.match(data_file.name if isinstance(data_file, Path) else data_file)
+            ],
+            origin_metadata=self.origin_metadata,
+        )
+
 
 class DataFilesDict(Dict[str, DataFilesList]):
     """
@@ -819,3 +832,9 @@ def __reduce__(self):
         """
         return DataFilesDict, (dict(sorted(self.items())),)
 
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
+        out = type(self)()
+        for key, data_files_list in self.items():
+            out[key] = data_files_list.filter_extensions(extensions)
+        return out
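
For intuition, here is a minimal standalone sketch of the pattern that the new filter_extensions builds; the extension list and file names are invented for illustration. The trailing (\..+)? group is what lets a supported extension keep one extra suffix, such as a compression suffix:

import re

# Illustrative extensions, mirroring what _MODULE_TO_EXTENSIONS
# provides for the "csv" module.
extensions = [".csv", ".tsv"]

# Same construction as DataFilesList.filter_extensions: escape the
# leading dot, OR the alternatives together, and allow one optional
# trailing suffix such as ".gz".
pattern = "|".join("\\" + ext for ext in extensions)
pattern = re.compile(f".*({pattern})(\\..+)?$")

files = ["train.csv", "train.csv.gz", "data.tsv", "notes.txt"]
print([f for f in files if pattern.match(f)])  # ['train.csv', 'train.csv.gz', 'data.tsv']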
14 changes: 13 additions & 1 deletion src/datasets/load.py
@@ -58,6 +58,7 @@
 from .packaged_modules import (
     _EXTENSION_TO_MODULE,
     _MODULE_SUPPORTS_METADATA,
+    _MODULE_TO_EXTENSIONS,
     _PACKAGED_DATASETS_MODULES,
     _hash_python_lines,
 )
@@ -347,6 +348,9 @@ def infer_module_for_data_files(
 ) -> Optional[Tuple[str, str]]:
     """Infer module (and builder kwargs) from list of data files.
+
+    It picks the module based on the most common file extension.
+    In case of a draw, ".parquet" is preferred; remaining ties are broken by extension name.
 
     Args:
         data_files_list (DataFilesList): List of data files.
         use_auth_token (bool or str, optional): Whether to use token or token to authenticate on the Hugging Face Hub
@@ -363,7 +367,13 @@
         for suffix in Path(filepath).suffixes
     )
     if extensions_counter:
-        for ext, _ in extensions_counter.most_common():
+
+        def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool, str]:
+            """Sort by count and set ".parquet" as the favorite in case of a draw"""
+            ext, count = ext_count
+            return (count, ext == ".parquet", ext)
+
+        for ext, _ in sorted(extensions_counter.items(), key=sort_key, reverse=True):
             if ext in _EXTENSION_TO_MODULE:
                 return _EXTENSION_TO_MODULE[ext]
             elif ext == ".zip":
@@ -642,6 +652,7 @@ def get_module(self) -> DatasetModule:
             raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
             raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.path}")
+        data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name])
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
@@ -783,6 +794,7 @@ def get_module(self) -> DatasetModule:
             raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
             raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.name}")
+        data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name])
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
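
To see the new tie-breaking in isolation, here is an illustrative re-run of the sort outside datasets; the extension counts are invented. With equal counts, ".parquet" ranks first because the boolean key participates in the descending sort:

from collections import Counter
from typing import Tuple

extensions_counter = Counter({".csv": 2, ".parquet": 2, ".json": 1})

def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool, str]:
    """Rank by count, then prefer ".parquet", then extension name."""
    ext, count = ext_count
    return (count, ext == ".parquet", ext)

print([ext for ext, _ in sorted(extensions_counter.items(), key=sort_key, reverse=True)])
# ['.parquet', '.csv', '.json']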
11 changes: 10 additions & 1 deletion src/datasets/packaged_modules/__init__.py
@@ -1,7 +1,7 @@
 import inspect
 import re
 from hashlib import sha256
-from typing import List
+from typing import Dict, List
 
 from .arrow import arrow
 from .audiofolder import audiofolder
@@ -39,6 +39,7 @@ def _hash_python_lines(lines: List[str]) -> str:
     "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
 }
 
+# Used to infer the module to use based on the data files extensions
 _EXTENSION_TO_MODULE = {
     ".csv": ("csv", {}),
     ".tsv": ("csv", {"sep": "\t"}),
@@ -53,3 +54,11 @@ def _hash_python_lines(lines: List[str]) -> str:
 _EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
 _MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+
+# Used to filter data files based on extensions given a module name
+_MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
+for _ext, (_module, _) in _EXTENSION_TO_MODULE.items():
+    _MODULE_TO_EXTENSIONS.setdefault(_module, []).append(_ext)
+
+_MODULE_TO_EXTENSIONS["imagefolder"].append(".zip")
+_MODULE_TO_EXTENSIONS["audiofolder"].append(".zip")
10 changes: 10 additions & 0 deletions tests/fixtures/files.py
@@ -471,6 +471,16 @@ def text2_path(tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def text_dir_with_unsupported_extension(tmp_path_factory):
+    data = ["0", "1", "2", "3"]
+    path = tmp_path_factory.mktemp("data") / "dataset.abc"
+    with open(path, "w") as f:
+        for item in data:
+            f.write(item + "\n")
+    return path
+
+
 @pytest.fixture(scope="session")
 def zip_text_path(text_path, text2_path, tmp_path_factory):
     path = tmp_path_factory.mktemp("data") / "dataset.text.zip"
6 changes: 6 additions & 0 deletions tests/test_load.py
@@ -899,6 +899,12 @@ def test_load_dataset_text_with_unicode_new_lines(text_path_with_unicode_new_lines):
     assert ds.num_rows == 3
 
 
+def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_extension):
+    data_files = str(text_dir_with_unsupported_extension)
+    ds = load_dataset("text", split="train", data_files=data_files)
+    assert ds.num_rows == 4
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub():
     with tempfile.TemporaryDirectory() as tmp_dir:
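
One note on why the test expects all 4 rows: the new filtering runs only where the module is inferred from the data files (the two get_module call sites patched above). Naming the packaged "text" builder explicitly bypasses inference, so a ".abc" file passed via data_files is still read line by line. A hypothetical invocation outside the test suite (the path is made up):

from datasets import load_dataset

# Explicitly naming the "text" builder skips module inference, so the
# unsupported ".abc" extension is not filtered out of data_files.
ds = load_dataset("text", split="train", data_files="path/to/dataset.abc")
print(ds.num_rows)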
