Skip to content

Commit

Permalink
Add IterableDataset.filter (huggingface#3826)
Browse files Browse the repository at this point in the history
* update instead of overwrite during streaming map + add with_indices, input_columns, remove_columns

* style

* tests

* update docs

* docs

* more docs

* add IterableDataset.filter

* tests

* docs

* docs
  • Loading branch information
lhoestq committed Mar 9, 2022
1 parent 0e6ab17 commit ba0aab0
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ The base class [`datasets.IterableDataset`] implements an iterable Dataset backe
- cast_column
- __iter__
- map
- filter
- shuffle
- skip
- take
Expand Down
2 changes: 1 addition & 1 deletion docs/source/process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ There are two options for filtering rows in a dataset: [`datasets.Dataset.select
[`datasets.Dataset.filter`] can also filter by indices if you set `with_indices=True`:

```py
>>> even_dataset = dataset.filter(lambda example, indice: indice % 2 == 0, with_indices=True)
>>> even_dataset = dataset.filter(lambda example, idx: idx % 2 == 0, with_indices=True)
>>> len(even_dataset)
1834
>>> len(dataset) / 2
Expand Down
22 changes: 22 additions & 0 deletions docs/source/stream.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,28 @@ See other examples of batch processing in [the batched map processing documentat

</Tip>

### Filter

You can filter rows in the dataset based on a predicate function using [`datasets.Dataset.filter`]. It returns rows that match a specified condition:

```py
>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', 'unshuffled_deduplicated_en', streaming=True, split='train')
>>> start_with_ar = dataset.filter(lambda example: example['text'].startswith('Ar'))
>>> next(iter(start_with_ar))
{'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)?...'}
```

[`datasets.Dataset.filter`] can also filter by indices if you set `with_indices=True`:

```py
>>> even_dataset = dataset.filter(lambda example, idx: idx % 2 == 0, with_indices=True)
>>> list(even_dataset.take(3))
[{'id': 0, 'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, ...'},
{'id': 2, 'text': '"I\'d love to help kickstart continued development! And 0 EUR/month...'},
{'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}]
```

## Stream in a training loop

[`datasets.IterableDataset`] can be integrated into a training loop. First, shuffle the dataset:
Expand Down
121 changes: 120 additions & 1 deletion src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,73 @@ def n_shards(self) -> int:
return self.ex_iterable.n_shards


class FilteredExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
ex_iterable: _BaseExamplesIterable,
function: Callable,
with_indices: bool = False,
input_columns: Optional[List[str]] = None,
batched: bool = False,
batch_size: int = 1000,
):
self.ex_iterable = ex_iterable
self.function = function
self.batched = batched
self.batch_size = batch_size
self.with_indices = with_indices
self.input_columns = input_columns

def __iter__(self):
iterator = iter(self.ex_iterable)
current_idx = 0
if self.batched:
for key, example in iterator:
# If batched, first build the batch
key_examples_list = [(key, example)] + [
(key, example) for key, example in islice(iterator, self.batch_size - 1)
]
keys, examples = zip(*key_examples_list)
batch = _examples_to_batch(examples)
# then compute the mask for the batch
inputs = batch
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append([current_idx + i for i in range(len(key_examples_list))])
mask = self.function(*function_args)
# yield one example at a time from the batch
for batch_idx, (key_example, to_keep) in enumerate(zip(key_examples_list, mask)):
if to_keep:
yield key_example
current_idx += batch_idx + 1
else:
for key, example in iterator:
# If not batched, we can apply the filtering function direcly
inputs = dict(example)
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append(current_idx)
to_keep = self.function(*function_args)
if to_keep:
yield key, example
current_idx += 1

def shuffle_data_sources(self, seed: Optional[int]) -> "MappedExamplesIterable":
"""Shuffle the wrapped examples iterable."""
return FilteredExamplesIterable(
self.ex_iterable.shuffle_data_sources(seed),
function=self.function,
with_indices=self.with_indices,
input_columns=self.input_columns,
batched=self.batched,
batch_size=self.batch_size,
)

@property
def n_shards(self) -> int:
return self.ex_iterable.n_shards


class BufferShuffledExamplesIterable(_BaseExamplesIterable):
def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator):
self.ex_iterable = ex_iterable
Expand Down Expand Up @@ -469,7 +536,7 @@ def map(
batched: bool = False,
batch_size: int = 1000,
remove_columns: Optional[Union[str, List[str]]] = None,
):
) -> "IterableDataset":
"""
Apply a function to all the examples in the iterable dataset (individually or in batches) and update them.
If your function returns a column that already exists, then it overwrites it.
Expand Down Expand Up @@ -522,6 +589,58 @@ def map(
shuffling=copy.deepcopy(self._shuffling),
)

def filter(
self,
function: Optional[Callable] = None,
with_indices=False,
input_columns: Optional[Union[str, List[str]]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
) -> "IterableDataset":
"""Apply a filter function to all the elements so that the dataset only includes examples according to the filter function.
The filtering is done on-the-fly when iterating over the dataset.
Args:
function (:obj:`Callable`): Callable with one of the following signatures:
- ``function(example: Union[Dict, Any]) -> bool`` if ``with_indices=False, batched=False``
- ``function(example: Union[Dict, Any], indices: int) -> bool`` if ``with_indices=True, batched=False``
- ``function(example: Union[Dict, Any]) -> List[bool]`` if ``with_indices=False, batched=True``
- ``function(example: Union[Dict, Any], indices: int) -> List[bool]`` if ``with_indices=True, batched=True``
If no function is provided, defaults to an always True function: ``lambda x: True``.
with_indices (:obj:`bool`, default `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
input_columns (:obj:`str` or `List[str]`, optional): The columns to be passed into `function` as
positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
batched (:obj:`bool`, defaults to `False`): Provide batch of examples to `function`
batch_size (:obj:`int`, optional, default ``1000``): Number of examples per batch provided to `function` if `batched=True`.
"""
if isinstance(input_columns, str):
input_columns = [input_columns]

# TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example)
info = copy.deepcopy(self._info)
info.features = None

# We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
ex_iterable = FilteredExamplesIterable(
TypedExamplesIterable(self._ex_iterable, self._info.features)
if self._info.features is not None
else self._ex_iterable,
function=function,
with_indices=with_indices,
input_columns=input_columns,
batched=batched,
batch_size=batch_size,
)
return iterable_dataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
shuffling=copy.deepcopy(self._shuffling),
)

def shuffle(
self, seed=None, generator: Optional[np.random.Generator] = None, buffer_size: int = 1000
) -> "IterableDataset":
Expand Down
89 changes: 89 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BufferShuffledExamplesIterable,
CyclingMultiSourcesExamplesIterable,
ExamplesIterable,
FilteredExamplesIterable,
IterableDataset,
MappedExamplesIterable,
RandomlyCyclingMultiSourcesExamplesIterable,
Expand Down Expand Up @@ -301,6 +302,94 @@ def test_mapped_examples_iterable_input_columns(generate_examples_fn, n, func, b
assert list(x for _, x in ex_iterable) == expected


@pytest.mark.parametrize(
"n, func, batch_size",
[
(3, lambda x: x["id"] % 2 == 0, None), # keep even number
(3, lambda x: [x["id"][0] % 2 == 0], 1), # same with bs=1
(5, lambda x: [i % 2 == 0 for i in x["id"]], 10), # same with bs=10
(25, lambda x: [i % 2 == 0 for i in x["id"]], 10), # same with bs=10
(3, lambda x: False, None), # return 0 examples
(3, lambda x: [False] * len(x["id"]), 10), # same with bs=10
],
)
def test_filtered_examples_iterable(generate_examples_fn, n, func, batch_size):
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
ex_iterable = FilteredExamplesIterable(
base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size
)
all_examples = [x for _, x in generate_examples_fn(n=n)]
if batch_size is None:
expected = [x for x in all_examples if func(x)]
else:
# For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
expected = []
for batch_offset in range(0, len(all_examples), batch_size):
examples = all_examples[batch_offset : batch_offset + batch_size]
batch = _examples_to_batch(examples)
mask = func(batch)
expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
if expected:
assert next(iter(ex_iterable))[1] == expected[0]
assert list(x for _, x in ex_iterable) == expected


@pytest.mark.parametrize(
"n, func, batch_size",
[
(3, lambda x, index: index % 2 == 0, None), # keep even number
(25, lambda x, indices: [idx % 2 == 0 for idx in indices], 10), # same with bs=10
],
)
def test_filtered_examples_iterable_with_indices(generate_examples_fn, n, func, batch_size):
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
ex_iterable = FilteredExamplesIterable(
base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, with_indices=True
)
all_examples = [x for _, x in generate_examples_fn(n=n)]
if batch_size is None:
expected = [x for idx, x in enumerate(all_examples) if func(x, idx)]
else:
# For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
expected = []
for batch_offset in range(0, len(all_examples), batch_size):
examples = all_examples[batch_offset : batch_offset + batch_size]
batch = _examples_to_batch(examples)
indices = list(range(batch_offset, batch_offset + len(examples)))
mask = func(batch, indices)
expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
assert next(iter(ex_iterable))[1] == expected[0]
assert list(x for _, x in ex_iterable) == expected


@pytest.mark.parametrize(
"n, func, batch_size, input_columns",
[
(3, lambda id_: id_ % 2 == 0, None, ["id"]), # keep even number
(25, lambda ids_: [i % 2 == 0 for i in ids_], 10, ["id"]), # same with bs=10
],
)
def test_filtered_examples_iterable_input_columns(generate_examples_fn, n, func, batch_size, input_columns):
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
ex_iterable = FilteredExamplesIterable(
base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, input_columns=input_columns
)
all_examples = [x for _, x in generate_examples_fn(n=n)]
columns_to_input = input_columns if isinstance(input_columns, list) else [input_columns]
if batch_size is None:
expected = [x for x in all_examples if func(*[x[col] for col in columns_to_input])]
else:
# For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
expected = []
for batch_offset in range(0, len(all_examples), batch_size):
examples = all_examples[batch_offset : batch_offset + batch_size]
batch = _examples_to_batch(examples)
mask = func(*[batch[col] for col in columns_to_input])
expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
assert next(iter(ex_iterable))[1] == expected[0]
assert list(x for _, x in ex_iterable) == expected


def test_skip_examples_iterable(generate_examples_fn):
total, count = 10, 2
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": total})
Expand Down

0 comments on commit ba0aab0

Please sign in to comment.