diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 36a76a5b57d..8915dec13ef 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -152,6 +152,7 @@ The base class [`datasets.IterableDataset`] implements an iterable Dataset backe - cast_column - __iter__ - map + - filter - shuffle - skip - take diff --git a/docs/source/process.mdx b/docs/source/process.mdx index b6e228b6b17..5eb535f6bd3 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -83,7 +83,7 @@ There are two options for filtering rows in a dataset: [`datasets.Dataset.select [`datasets.Dataset.filter`] can also filter by indices if you set `with_indices=True`: ```py ->>> even_dataset = dataset.filter(lambda example, indice: indice % 2 == 0, with_indices=True) +>>> even_dataset = dataset.filter(lambda example, idx: idx % 2 == 0, with_indices=True) >>> len(even_dataset) 1834 >>> len(dataset) / 2 diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx index 7a27ec8e03b..21240157a0b 100644 --- a/docs/source/stream.mdx +++ b/docs/source/stream.mdx @@ -184,6 +184,28 @@ See other examples of batch processing in [the batched map processing documentat +### Filter + +You can filter rows in the dataset based on a predicate function using [`datasets.IterableDataset.filter`]. 
It returns rows that match a specified condition: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset('oscar', 'unshuffled_deduplicated_en', streaming=True, split='train') +>>> start_with_ar = dataset.filter(lambda example: example['text'].startswith('Ar')) +>>> next(iter(start_with_ar)) +{'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)?...'} +``` + +[`datasets.IterableDataset.filter`] can also filter by indices if you set `with_indices=True`: + +```py +>>> even_dataset = dataset.filter(lambda example, idx: idx % 2 == 0, with_indices=True) +>>> list(even_dataset.take(3)) +[{'id': 0, 'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, ...'}, + {'id': 2, 'text': '"I\'d love to help kickstart continued development! And 0 EUR/month...'}, + {'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}] +``` + ## Stream in a training loop [`datasets.IterableDataset`] can be integrated into a training loop. 
class FilteredExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that yields only the examples accepted by a predicate.

    The wrapped iterable is consumed lazily: the predicate is evaluated while
    iterating, either one example at a time or on batches of examples.

    Args:
        ex_iterable: The examples iterable to filter. It is iterated as
            ``(key, example)`` pairs.
        function: The predicate. When ``batched=False`` it receives one example
            and its return value is tested for truthiness. When ``batched=True``
            it receives a batch and must return an iterable mask with one entry
            per example of the batch.
        with_indices: If True, the running position of the example (or the list
            of positions of the batch) is appended to the predicate arguments.
        input_columns: If not None, the predicate receives the values of these
            columns as positional arguments instead of the whole example/batch.
        batched: Whether to call the predicate on batches.
        batch_size: Maximum number of examples per batch when ``batched=True``.
    """

    def __init__(
        self,
        ex_iterable: _BaseExamplesIterable,
        function: Callable,
        with_indices: bool = False,
        input_columns: Optional[List[str]] = None,
        batched: bool = False,
        batch_size: int = 1000,
    ):
        self.ex_iterable = ex_iterable
        self.function = function
        self.batched = batched
        self.batch_size = batch_size
        self.with_indices = with_indices
        self.input_columns = input_columns

    def __iter__(self):
        iterator = iter(self.ex_iterable)
        # Position, in the wrapped iterable, of the next example to be examined.
        current_idx = 0
        if self.batched:
            for key, example in iterator:
                # If batched, first build the batch: the example just pulled by the
                # ``for`` plus up to ``batch_size - 1`` more taken from the same iterator.
                key_examples_list = [(key, example)] + list(islice(iterator, self.batch_size - 1))
                num_examples = len(key_examples_list)
                batch = _examples_to_batch([example for _, example in key_examples_list])
                # then compute the mask for the batch
                function_args = [batch] if self.input_columns is None else [batch[col] for col in self.input_columns]
                if self.with_indices:
                    function_args.append(list(range(current_idx, current_idx + num_examples)))
                mask = self.function(*function_args)
                # yield one example at a time from the batch
                for key_example, to_keep in zip(key_examples_list, mask):
                    if to_keep:
                        yield key_example
                # Advance by the real batch size rather than by a loop variable, so the
                # indices stay correct even if the mask is empty or shorter than the batch.
                current_idx += num_examples
        else:
            for key, example in iterator:
                # If not batched, we can apply the filtering function directly
                inputs = dict(example)
                function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                if self.with_indices:
                    function_args.append(current_idx)
                to_keep = self.function(*function_args)
                if to_keep:
                    yield key, example
                current_idx += 1

    def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable":
        """Shuffle the wrapped examples iterable and return a new filter over it with the same settings."""
        return FilteredExamplesIterable(
            self.ex_iterable.shuffle_data_sources(seed),
            function=self.function,
            with_indices=self.with_indices,
            input_columns=self.input_columns,
            batched=self.batched,
            batch_size=self.batch_size,
        )

    @property
    def n_shards(self) -> int:
        """Number of shards, as reported by the wrapped examples iterable."""
        return self.ex_iterable.n_shards
def filter(
    self,
    function: Optional[Callable] = None,
    with_indices: bool = False,
    input_columns: Optional[Union[str, List[str]]] = None,
    batched: bool = False,
    batch_size: Optional[int] = 1000,
) -> "IterableDataset":
    """Apply a filter function to all the elements so that the dataset only includes examples according to the filter function.
    The filtering is done on-the-fly when iterating over the dataset.

    Args:
        function (:obj:`Callable`): Callable with one of the following signatures:

            - ``function(example: Union[Dict, Any]) -> bool`` if ``with_indices=False, batched=False``
            - ``function(example: Union[Dict, Any], idx: int) -> bool`` if ``with_indices=True, batched=False``
            - ``function(batch: Union[Dict, Any]) -> List[bool]`` if ``with_indices=False, batched=True``
            - ``function(batch: Union[Dict, Any], indices: List[int]) -> List[bool]`` if ``with_indices=True, batched=True``

            If no function is provided, defaults to an always True function: ``lambda x: True``.
        with_indices (:obj:`bool`, default `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
        input_columns (:obj:`str` or `List[str]`, optional): The columns to be passed into `function` as
            positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
        batched (:obj:`bool`, defaults to `False`): Provide batch of examples to `function`
        batch_size (:obj:`int`, optional, default ``1000``): Number of examples per batch provided to `function` if `batched=True`.

    Returns:
        :class:`IterableDataset`: a new dataset wrapping the filtered examples; its ``features`` are reset to ``None``.
    """
    if function is None:
        # Honour the documented default instead of failing with "NoneType is not callable"
        # at iteration time. The always-True predicate is evaluated per example with a
        # single argument so that it is valid whatever `batched`, `with_indices` or
        # `input_columns` were requested: every example is kept in all cases.
        function = lambda x: True  # noqa: E731
        with_indices, input_columns, batched = False, None, False

    if isinstance(input_columns, str):
        input_columns = [input_columns]

    # TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example)
    info = copy.deepcopy(self._info)
    info.features = None

    # We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
    source_iterable = (
        TypedExamplesIterable(self._ex_iterable, self._info.features)
        if self._info.features is not None
        else self._ex_iterable
    )
    ex_iterable = FilteredExamplesIterable(
        source_iterable,
        function=function,
        with_indices=with_indices,
        input_columns=input_columns,
        batched=batched,
        batch_size=batch_size,
    )
    return iterable_dataset(
        ex_iterable=ex_iterable,
        info=info,
        split=self._split,
        format_type=self._format_type,
        shuffling=copy.deepcopy(self._shuffling),
    )
@pytest.mark.parametrize(
    "n, func, batch_size",
    [
        (3, lambda x: x["id"] % 2 == 0, None),  # keep even number
        (3, lambda x: [x["id"][0] % 2 == 0], 1),  # same with bs=1
        (5, lambda x: [i % 2 == 0 for i in x["id"]], 10),  # same with bs=10
        (25, lambda x: [i % 2 == 0 for i in x["id"]], 10),  # same with bs=10
        (3, lambda x: False, None),  # return 0 examples
        (3, lambda x: [False] * len(x["id"]), 10),  # same with bs=10
    ],
)
def test_filtered_examples_iterable(generate_examples_fn, n, func, batch_size):
    """Filtering keeps exactly the examples accepted by `func`, batched or not."""
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = FilteredExamplesIterable(
        base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size
    )
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    if batch_size is None:
        expected = [x for x in all_examples if func(x)]
    else:
        # For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
        expected = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            mask = func(batch)
            expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
    # Some parametrizations filter everything out, so only peek at the first item when there is one
    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected


@pytest.mark.parametrize(
    "n, func, batch_size",
    [
        (3, lambda x, index: index % 2 == 0, None),  # keep even number
        (25, lambda x, indices: [idx % 2 == 0 for idx in indices], 10),  # same with bs=10
    ],
)
def test_filtered_examples_iterable_with_indices(generate_examples_fn, n, func, batch_size):
    """With `with_indices=True` the predicate also receives the running example indices."""
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = FilteredExamplesIterable(
        base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, with_indices=True
    )
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    if batch_size is None:
        expected = [x for idx, x in enumerate(all_examples) if func(x, idx)]
    else:
        # For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
        expected = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            # Indices are global positions in the dataset, not positions inside the batch
            indices = list(range(batch_offset, batch_offset + len(examples)))
            mask = func(batch, indices)
            expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected


@pytest.mark.parametrize(
    "n, func, batch_size, input_columns",
    [
        (3, lambda id_: id_ % 2 == 0, None, ["id"]),  # keep even number
        (25, lambda ids_: [i % 2 == 0 for i in ids_], 10, ["id"]),  # same with bs=10
    ],
)
def test_filtered_examples_iterable_input_columns(generate_examples_fn, n, func, batch_size, input_columns):
    """With `input_columns` the predicate receives the selected column values as positional arguments."""
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = FilteredExamplesIterable(
        base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, input_columns=input_columns
    )
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    columns_to_input = input_columns if isinstance(input_columns, list) else [input_columns]
    if batch_size is None:
        expected = [x for x in all_examples if func(*[x[col] for col in columns_to_input])]
    else:
        # For batched filter we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
        expected = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            mask = func(*[batch[col] for col in columns_to_input])
            expected.extend([x for x, to_keep in zip(examples, mask) if to_keep])
    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected