diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py
index 0f5fbf0370e708..d58762035ef7f8 100644
--- a/src/transformers/pipelines/question_answering.py
+++ b/src/transformers/pipelines/question_answering.py
@@ -1,3 +1,4 @@
+import types
 import warnings
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -22,8 +23,11 @@
     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
 
+Dataset = None
+
 if is_torch_available():
     import torch
+    from torch.utils.data import Dataset
 
     from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
@@ -82,6 +86,11 @@ def __call__(self, *args, **kwargs):
         else:
             raise ValueError(f"Unknown arguments {kwargs}")
 
+        # When user is sending a generator we need to trust it's a valid example
+        generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
+        if isinstance(inputs, generator_types):
+            return inputs
+
         # Normalize inputs
         if isinstance(inputs, dict):
             inputs = [inputs]
@@ -245,12 +254,18 @@ def __call__(self, *args, **kwargs):
         """
         # Convert inputs to features
+
         examples = self._args_parser(*args, **kwargs)
-        if len(examples) == 1:
+        if isinstance(examples, (list, tuple)) and len(examples) == 1:
             return super().__call__(examples[0], **kwargs)
         return super().__call__(examples, **kwargs)
 
     def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
+        # XXX: This is special, args_parser will not handle anything generator or dataset like
+        # For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
+        # So we still need a little sanitation here.
+        if isinstance(example, dict):
+            example = SquadExample(None, example["question"], example["context"], None, None, None)
         if max_seq_len is None:
             max_seq_len = min(self.tokenizer.model_max_length, 384)
diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py
index f34237612c11a9..c3a0da2f2b5e9a 100644
--- a/tests/pipelines/test_pipelines_question_answering.py
+++ b/tests/pipelines/test_pipelines_question_answering.py
@@ -125,6 +125,18 @@ def test_small_model_pt(self):
 
         self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
 
+    @require_torch
+    def test_small_model_pt_iterator(self):
+        # https://github.com/huggingface/transformers/issues/18510
+        pipe = pipeline(model="sshleifer/tiny-distilbert-base-cased-distilled-squad", batch_size=16, framework="pt")
+
+        def data():
+            for i in range(10):
+                yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
+
+        for outputs in pipe(data()):
+            self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+
     @require_torch
     def test_small_model_pt_softmax_trick(self):
         question_answerer = pipeline(
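
For context, a minimal usage sketch of what this patch enables: streaming examples through the QA pipeline as a generator, so the argument handler passes the generator through and batching happens lazily. The model checkpoint and batch size are taken from the added test; the surrounding script is illustrative, not part of the diff.

from transformers import pipeline

# Generators (and torch Datasets) are now passed through the argument handler
# untouched, so each yielded item must already be a valid example.
pipe = pipeline(
    "question-answering",
    model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    batch_size=16,
)

def data():
    # Each dict needs a "question" and a "context" key; preprocess() converts
    # it to a SquadExample on the fly.
    for _ in range(10):
        yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}

# Iterating the pipeline over a generator yields one result dict per example.
for answer in pipe(data()):
    print(answer["answer"], answer["score"])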