Skip to content

Commit

Permalink
Improve error handling in make_dataset (pytorch#3496)
Browse files Browse the repository at this point in the history
* factor out find_classes

* use find_classes in video datasets

* adapt old tests
  • Loading branch information
pmeier authored Mar 24, 2021
1 parent 19ad0bb commit 0818c68
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 42 deletions.
7 changes: 2 additions & 5 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def test_imagefolder(self):

def test_imagefolder_empty(self):
with get_tmp_dir() as root:
with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(root, loader=lambda x: x)

with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False
)
Expand Down Expand Up @@ -1092,9 +1092,6 @@ def inject_fake_data(self, tmpdir, config):

return num_videos_per_class * len(classes)

def test_not_found_or_corrupted(self):
self.skipTest("Dataset currently does not handle the case of no found videos.")


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
Expand Down
90 changes: 65 additions & 25 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,52 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset structured as follows:
.. code::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
Args:
directory (str): Root directory path.
Raises:
FileNotFoundError: If ``directory`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx


def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
by :func:`find_classes`.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
Expand All @@ -51,21 +86,34 @@ def make_dataset(
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)

if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

if extensions is not None:

def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

is_valid_file = cast(Callable[[str], bool], is_valid_file)

instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
Expand All @@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool:
if is_valid_file(path):
item = path, class_index
instances.append(item)

if target_class not in available_classes:
available_classes.add(target_class)

empty_classes = available_classes - set(class_to_idx.keys())
if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {', '.join(extensions)}"
raise FileNotFoundError(msg)

return instances


Expand Down Expand Up @@ -125,11 +184,6 @@ def __init__(
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)

self.loader = loader
self.extensions = extensions
Expand All @@ -148,23 +202,9 @@ def make_dataset(
) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
@staticmethod
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
return find_classes(dir)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/hmdb51.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -62,8 +62,7 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
raise ValueError("fold should be between 1 and 3, got {}".format(fold))

extensions = ('avi',)
classes = sorted(list_dir(root))
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(
self.root,
class_to_idx,
Expand All @@ -89,7 +88,6 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
self.full_video_clips = video_clips
self.fold = fold
self.train = train
self.classes = classes
self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -56,10 +56,8 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
_video_min_dimension=0, _audio_samples=0, _audio_channels=0):
super(Kinetics400, self).__init__(root)

classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips(
video_list,
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/ucf101.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -55,10 +55,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
self.fold = fold
self.train = train

classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips(
video_list,
Expand Down

0 comments on commit 0818c68

Please sign in to comment.