Improve the DATA_LIMIT attribute to be able to handle more use cases (facebookresearch#216)

Summary:
This PR is a draft, pushed for visibility and discussion.

The additional use cases I propose to support are:
- being able to sub-select part of a dataset in a balanced way (each label is included the same number of times)
- being able to sub-select exclusive parts of the same dataset (for instance, to have a validation set that does not intersect with the training set, which is useful for HP searches)
- making sure that this sub-sampling is deterministic (same seed across all distributed workers)

This would avoid having to create sub-sets of datasets such as ImageNet to test on 1% of each label, for instance. It would also allow benchmarking SSL algorithms in low-data regimes in a more flexible way.
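As an illustration, a hedged configuration sketch of the validation use case (the split name, sample counts, and nesting mirror the example added to defaults.yaml in this PR; treat the exact values as hypothetical):

```yaml
config:
  DATA:
    TEST:
      DATA_LIMIT: 500            # keep only 500 samples for validation
      DATA_LIMIT_SAMPLING:
        SEED: 0                  # same seed on every distributed worker
        IS_BALANCED: True        # each class contributes equally
        SKIP_NUM_SAMPLES: 1000   # skip the 1000 samples reserved for training
```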

/!\ This PR introduces a breaking change (DATA_LIMIT is not an integer anymore but a structure)

This PR includes:
- unit tests for the sub-sampling strategies
- update of all configurations using the DATA_LIMIT attribute

Pull Request resolved: facebookresearch#216

Reviewed By: prigoyal

Differential Revision: D26923493

Pulled By: QuentinDuval

fbshipit-source-id: b4ed7c61369587ac9349218933b5eed357c19b06
QuentinDuval authored and facebook-github-bot committed Mar 23, 2021
1 parent a67277c commit 930e0c9
Showing 6 changed files with 227 additions and 33 deletions.
51 changes: 51 additions & 0 deletions tests/test_data_helpers.py
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import numpy as np
from vissl.data.data_helper import balanced_sub_sampling, unbalanced_sub_sampling


class TestDataLimitSubSampling(unittest.TestCase):
"""
Testing the sub-sampling methods underlying the DATA_LIMIT attribute
"""

def test_unbalanced_sub_sampling(self):
labels = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 0])

indices1 = unbalanced_sub_sampling(len(labels), num_samples=8, skip_samples=0)
self.assertEqual(8, len(indices1))
self.assertEqual(len(indices1), len(set(indices1)), "indices must be unique")

indices2 = unbalanced_sub_sampling(len(labels), num_samples=8, skip_samples=2)
self.assertEqual(8, len(indices2))
self.assertEqual(len(indices2), len(set(indices2)), "indices must be unique")

self.assertTrue(
np.array_equal(indices1[2:], indices2[:-2]),
"skipping samples should slide the window",
)

def test_balanced_sub_sampling(self):
labels = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 0])
unique_labels = set(labels)

indices1 = balanced_sub_sampling(labels, num_samples=8, skip_samples=0)
values, counts = np.unique(labels[indices1], return_counts=True)
self.assertEqual(8, len(indices1))
self.assertEqual(
set(values),
set(unique_labels),
"at least one of each label should be selected",
)
self.assertEqual(2, np.min(counts), "at least two of each label is selected")
self.assertEqual(2, np.max(counts), "at most two of each label is selected")

indices2 = balanced_sub_sampling(labels, num_samples=8, skip_samples=4)
self.assertEqual(8, len(indices2))
self.assertEqual(
4,
len(set(indices1) & set(indices2)),
"skipping samples should slide the window",
)
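For reference, these tests should be runnable locally with something like `python -m unittest tests.test_data_helpers` (the exact invocation is an assumption based on the file location above).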
23 changes: 22 additions & 1 deletion vissl/config/defaults.yaml
@@ -250,8 +250,25 @@ config:
# accommodate the smoothed labels. See
# LOSS.cross_entropy_multiple_output_single_target for more information.
#
# limit the amount of data used in training. If set to -1, full dataset is used.
# Limit the amount of data used in training. If set to -1, full dataset is used.
#
DATA_LIMIT: -1
#
# Specifies how the DATA_LIMIT samples are sampled
#
# Example: to select a range of 500 samples for validation, skipping the first 1000 samples (say these are
# already used in the training split) and sub-sampling these elements such that each class appears equally:
# DATA_LIMIT: 500
# DATA_LIMIT_SAMPLING:
#   SEED: 0
#   IS_BALANCED: True
#   SKIP_NUM_SAMPLES: 1000
#
DATA_LIMIT_SAMPLING:
  SEED: 0
  IS_BALANCED: False
  SKIP_NUM_SAMPLES: 0

# whether the data specified (whether file list or directory) should be copied locally
# on the machine where training is happening.
COPY_TO_LOCAL_DISK: False
@@ -286,6 +303,10 @@ config:
COLLATE_FUNCTION: "default_collate"
COLLATE_FUNCTION_PARAMS: {}
DATA_LIMIT: -1
DATA_LIMIT_SAMPLING:
  SEED: 0
  IS_BALANCED: False
  SKIP_NUM_SAMPLES: 0
DATASET_NAMES: ["imagenet1k_folder"]
COPY_TO_LOCAL_DISK: False
COPY_DESTINATION_DIR: ""
66 changes: 64 additions & 2 deletions vissl/data/data_helper.py
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import contextlib
import logging
import queue

@@ -24,6 +24,67 @@ def get_mean_image(crop_size):
return img


@contextlib.contextmanager
def with_temporary_numpy_seed(sampling_seed: int):
"""
Context manager to run a specific portion of code with a given seed:
restores the original numpy random state after the execution of the block
"""
original_random_state = np.random.get_state()
np.random.seed(sampling_seed)
yield
np.random.set_state(original_random_state)


def unbalanced_sub_sampling(
total_num_samples: int, num_samples: int, skip_samples: int = 0, seed: int = 0
) -> np.ndarray:
"""
Given an original dataset of size 'total_num_samples', sub-sample part of the dataset such that
the sub-sampling is deterministic (identical across distributed workers).
Return the selected indices.
"""
with with_temporary_numpy_seed(seed):
return np.random.choice(
total_num_samples, size=skip_samples + num_samples, replace=False
)[skip_samples:]


def balanced_sub_sampling(
labels: np.ndarray, num_samples: int, skip_samples: int = 0, seed: int = 0
) -> np.ndarray:
"""
Given all the labels of a dataset, sub-sample a part of the labels such that:
- the number of samples of each label differs by at most one
- the sub-sampling is deterministic (identical across distributed workers)
Return the indices of the selected labels.
"""
groups = {}
for i, label in enumerate(labels):
groups.setdefault(label, []).append(i)

unique_labels = sorted(groups.keys())
skip_quotient, skip_rest = divmod(skip_samples, len(unique_labels))
sample_quotient, sample_rest = divmod(num_samples, len(unique_labels))
assert (
sample_quotient > 0
), "the number of samples should be at least equal to the number of classes"

with with_temporary_numpy_seed(seed):
for i, label in enumerate(unique_labels):
label_indices = groups[label]
num_label_samples = sample_quotient + (1 if i < sample_rest else 0)
skip_label_samples = skip_quotient + (1 if i < skip_rest else 0)
permuted_indices = np.random.choice(
label_indices,
size=skip_label_samples + num_label_samples,
replace=False,
)
groups[label] = permuted_indices[skip_label_samples:]

return np.concatenate([groups[label] for label in unique_labels])


class StatefulDistributedSampler(DistributedSampler):
"""
More fine-grained state DataSampler that uses training iteration and epoch
@@ -70,7 +131,8 @@ def __iter__(self):
assert self.batch_size > 0, "batch_size not set for the sampler"

# resume the sampler
indices = indices[(self.start_iter * self.batch_size) :]
start_index = self.start_iter * self.batch_size
indices = indices[start_index:]
return iter(indices)

def set_start_iter(self, start_iter):
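A minimal usage sketch of the two new helpers above (the dataset, class count, and sample counts are made up; the non-intersection relies on the sliding-window behavior exercised by the unit tests):

```python
import numpy as np

from vissl.data.data_helper import balanced_sub_sampling, unbalanced_sub_sampling

# Hypothetical dataset: 1000 samples spread evenly over 4 classes
labels = np.tile(np.arange(4), 250)

# 100 training samples, 25 per class, deterministic for a fixed seed
train_indices = balanced_sub_sampling(labels, num_samples=100, skip_samples=0, seed=0)

# 100 validation samples drawn with the same seed but with the skip slid
# past the training window, so the two subsets should not intersect
val_indices = balanced_sub_sampling(labels, num_samples=100, skip_samples=100, seed=0)
assert len(set(train_indices) & set(val_indices)) == 0

# Label-agnostic variant: only the dataset size is needed
random_indices = unbalanced_sub_sampling(len(labels), num_samples=100, skip_samples=0, seed=0)
```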
7 changes: 0 additions & 7 deletions vissl/data/disk_dataset.py
@@ -87,13 +87,6 @@ def _load_data(self, path):
# Avoid creating it over and over again.
self.is_initialized = True

if self.cfg["DATA"][self.split]["DATA_LIMIT"] > 0:
limit = self.cfg["DATA"][self.split]["DATA_LIMIT"]
if self.data_source == "disk_filelist":
self.image_dataset = self.image_dataset[:limit]
elif self.data_source == "disk_folder":
self.image_dataset.samples = self.image_dataset.samples[:limit]

def num_samples(self):
"""
Size of the dataset
102 changes: 84 additions & 18 deletions vissl/data/ssl_dataset.py
@@ -7,8 +7,10 @@
from fvcore.common.file_io import PathManager
from torch.utils.data import Dataset
from vissl.data import dataset_catalog
from vissl.data.data_helper import balanced_sub_sampling, unbalanced_sub_sampling
from vissl.data.ssl_transforms import get_transform
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.hydra_config import AttrDict


def _convert_lbl_to_long(lbl):
@@ -60,7 +62,7 @@ class GenericSSLDataset(Dataset):
}
"""

def __init__(self, cfg, split, dataset_source_map):
def __init__(self, cfg: AttrDict, split: str, dataset_source_map):
self.split = split
self.cfg = cfg
self.data_objs = []
@@ -72,8 +74,12 @@ def __init__(self, cfg, split, dataset_source_map):
self.label_sources = self.cfg["DATA"][split].LABEL_SOURCES
self.dataset_names = self.cfg["DATA"][split].DATASET_NAMES
self.label_type = self.cfg["DATA"][split].LABEL_TYPE
self.data_limit = self.cfg["DATA"][split].DATA_LIMIT
self.data_limit_sampling = self._get_data_limit_sampling(cfg, split)
self.transform = get_transform(self.cfg["DATA"][split].TRANSFORMS)
self._labels_init = False
self._subset_initialized = False
self.image_and_label_subset = None
self._verify_data_sources(split, dataset_source_map)
self._get_data_files(split)

@@ -95,6 +101,13 @@ def __init__(self, cfg, split, dataset_source_map):
)
)

@staticmethod
def _get_data_limit_sampling(cfg: AttrDict, split: str) -> AttrDict:
default_sampling = AttrDict(
{"SEED": 0, "IS_BALANCED": False, "SKIP_NUM_SAMPLES": 0}
)
return cfg["DATA"][split].get("DATA_LIMIT_SAMPLING", default_sampling)

def _verify_data_sources(self, split, dataset_source_map):
"""
For each data source, verify that the specified data source
@@ -221,7 +234,54 @@ def _load_labels(self):
raise ValueError(f"unknown label source: {label_source}")
self.label_objs.append(labels)

def __getitem__(self, idx):
def _can_random_subset_data_sources(self):
"""
Backward compatibility: some plug-in data sources have their own internal
support for data_limit, and we keep the same behavior here (we ignore
the DATA_LIMIT attribute in GenericSSLDataset for those sources)
"""
valid_datasets = {
"disk_filelist",
"disk_folder",
"torchvision_dataset",
"synthetic",
}
return all(source in valid_datasets for source in self.data_sources)

def _init_image_and_label_subset(self):
"""
If DATA_LIMIT = K >= 0, we reduce the size of the dataset from N to K.
This function will create a mapping from [0, K) to [0, N), using the
parameters specified in the DATA_LIMIT_SAMPLING configuration. This
mapping is then cached and used for all __getitem__ calls to map
the external indices from [0, K) to the internal [0, N) indices.
This function makes the assumption that there is one data source only
or that all data sources have the same length (same as __getitem__).
"""

# Use one of the two random sampling strategies:
# - unbalanced: random sampling is agnostic to labels
# - balanced: makes sure all labels are equally represented
if not self.data_limit_sampling.IS_BALANCED:
self.image_and_label_subset = unbalanced_sub_sampling(
total_num_samples=len(self.data_objs[0]),
num_samples=self.data_limit,
skip_samples=self.data_limit_sampling.SKIP_NUM_SAMPLES,
seed=self.data_limit_sampling.SEED,
)
else:
assert len(self.label_objs), "Balanced sampling requires labels"
self.image_and_label_subset = balanced_sub_sampling(
labels=self.label_objs[0],
num_samples=self.data_limit,
skip_samples=self.data_limit_sampling.SKIP_NUM_SAMPLES,
seed=self.data_limit_sampling.SEED,
)
self._subset_initialized = True

def __getitem__(self, idx: int):
"""
Get the input sample for the minibatch for a specified data index.
For each data object (if we are loading several datasets in a minibatch),
@@ -241,11 +301,17 @@ def __getitem__(self, idx):
self._load_labels()
self._labels_init = True

subset_idx = idx
if self.data_limit >= 0 and self._can_random_subset_data_sources():
if not self._subset_initialized:
self._init_image_and_label_subset()
subset_idx = self.image_and_label_subset[idx]

# TODO: this doesn't yet handle the case where the length of datasets
# could be different.
item = {"data": [], "data_valid": [], "data_idx": []}
for source in self.data_objs:
data, valid = source[idx]
for data_source in self.data_objs:
data, valid = data_source[subset_idx]
item["data"].append(data)
item["data_idx"].append(idx)
item["data_valid"].append(1 if valid else -1)
@@ -261,11 +327,11 @@ def __getitem__(self, idx):
# to its functionality.
if (len(self.label_objs) > 0) or self.label_type == "standard":
item["label"] = []
for source in self.label_objs:
if isinstance(source, list):
lbl = [entry[idx] for entry in source]
for label_source in self.label_objs:
if isinstance(label_source, list):
lbl = [entry[subset_idx] for entry in label_source]
else:
lbl = _convert_lbl_to_long(source[idx])
lbl = _convert_lbl_to_long(label_source[subset_idx])
item["label"].append(lbl)
elif self.label_type == "sample_index":
item["label"] = []
@@ -285,10 +351,17 @@ def __getitem__(self, idx):

def __len__(self):
"""
Size of the dataset. Assumption made there is only one
data source
Size of the dataset. Assumption made there is only one data source
"""
return len(self.data_objs[0])
return self.num_samples(0)

def num_samples(self, source_idx=0):
"""
Size of the dataset. Assumption made there is only one data source
"""
if self.data_limit >= 0:
return self.data_limit
return len(self.data_objs[source_idx])

def get_image_paths(self):
"""
@@ -312,13 +385,6 @@ def get_available_splits(self, dataset_config):
"""
return [key for key in dataset_config if key.lower() in ["train", "test"]]

def num_samples(self, source_idx=0):
"""
Size of the dataset. Assumption made there is only one
data source
"""
return len(self.data_objs[source_idx])

def get_batchsize_per_replica(self):
"""
Get the batch size per trainer
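To make the index remapping performed by `_init_image_and_label_subset` concrete, here is a standalone sketch (it does not use the VISSL classes; sizes and seed are arbitrary) of how an external index in [0, K) is translated to an internal index in [0, N):

```python
import numpy as np

N, K = 10_000, 100  # full dataset size (N) and DATA_LIMIT (K)
rng = np.random.RandomState(0)
subset = rng.choice(N, size=K, replace=False)  # computed once and cached

def getitem(idx: int) -> int:
    # external indices in [0, K) are remapped to internal indices in [0, N);
    # the real __getitem__ then reads data_objs[0][subset[idx]]
    return int(subset[idx])

print(getitem(0), getitem(K - 1))
```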
11 changes: 6 additions & 5 deletions vissl/data/synthetic_dataset.py
@@ -19,15 +19,16 @@ class SyntheticImageDataset(Dataset):
data_source (string, Optional): data source ("synthetic") [not used]
"""

def __init__(self, cfg, path, split, dataset_name, data_source="synthetic"):
DEFAULT_SIZE = 50_000

def __init__(
self, cfg, path: str, split: str, dataset_name: str, data_source="synthetic"
):
super(SyntheticImageDataset, self).__init__()
self.cfg = cfg
self.split = split
self.data_source = data_source
self._num_samples = 50000
# by default, pretend the dataset holds 50K images, or the user-specified DATA_LIMIT if larger
if cfg.DATA[split].DATA_LIMIT > 0:
self._num_samples = cfg.DATA[split].DATA_LIMIT
self._num_samples = max(self.DEFAULT_SIZE, cfg.DATA[split].DATA_LIMIT)

def num_samples(self):
"""
