Skip to content

Commit

Permalink
[DX fix] Fixing QA pipeline streaming a dataset. (huggingface#18516)
Browse files Browse the repository at this point in the history
* [DX fix] Fixing QA pipeline streaming a dataset.

QuestionAnsweringArgumentHandler would iterate over the whole dataset up front,
effectively defeating the lazy, streaming behavior of the pipeline.
This restores that behavior when using a `Dataset` or a generator, since
those are meant to be consumed lazily.

* Handling TF better.
  • Loading branch information
Narsil authored and oneraghavan committed Sep 26, 2022
1 parent d6326c8 commit 1550bda
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/transformers/pipelines/question_answering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import types
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
Expand All @@ -22,8 +23,11 @@

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING

Dataset = None

if is_torch_available():
import torch
from torch.utils.data import Dataset

from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING

Expand Down Expand Up @@ -82,6 +86,11 @@ def __call__(self, *args, **kwargs):
else:
raise ValueError(f"Unknown arguments {kwargs}")

# When user is sending a generator we need to trust it's a valid example
generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
if isinstance(inputs, generator_types):
return inputs

# Normalize inputs
if isinstance(inputs, dict):
inputs = [inputs]
Expand Down Expand Up @@ -245,12 +254,18 @@ def __call__(self, *args, **kwargs):
"""

# Convert inputs to features

examples = self._args_parser(*args, **kwargs)
if len(examples) == 1:
if isinstance(examples, (list, tuple)) and len(examples) == 1:
return super().__call__(examples[0], **kwargs)
return super().__call__(examples, **kwargs)

def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
# XXX: This is special: args_parser will not handle anything generator- or dataset-like.
# For those we expect the user to send a simple valid example, either directly as a SquadExample or as a simple dict.
# So we still need a little sanitization here.
if isinstance(example, dict):
example = SquadExample(None, example["question"], example["context"], None, None, None)

if max_seq_len is None:
max_seq_len = min(self.tokenizer.model_max_length, 384)
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/test_pipelines_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ def test_small_model_pt(self):

self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})

@require_torch
def test_small_model_pt_iterator(self):
# Regression test for the issue linked below: the question-answering pipeline is called
# with a generator of question/context dicts and its results are consumed by iteration.
# https://github.com/huggingface/transformers/issues/18510
pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")

# Generator yielding ten identical examples, each a plain dict with "question" and "context" keys.
def data():
for i in range(10):
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}

# Every result produced for the generator must match the same simplified expected answer.
for outputs in pipe(data()):
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})

@require_torch
def test_small_model_pt_softmax_trick(self):
question_answerer = pipeline(
Expand Down

0 comments on commit 1550bda

Please sign in to comment.