From 47940177eda240fe7986dc31c3c6a8cd8d949c6b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Aug 2022 18:50:02 +0200 Subject: [PATCH] Adding a new `align_to_words` param to qa pipeline. (#18010) * Adding a new `align_to_words` param to qa pipeline. * Update src/transformers/pipelines/question_answering.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Import protection. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../pipelines/question_answering.py | 48 +++++++++++++++---- .../test_pipelines_question_answering.py | 23 +++++++++ 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index d58762035ef7f8..6f07382dc57c6b 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -8,7 +8,14 @@ from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features from ..modelcard import ModelCard from ..tokenization_utils import PreTrainedTokenizer -from ..utils import PaddingStrategy, add_end_docstrings, is_tf_available, is_torch_available, logging +from ..utils import ( + PaddingStrategy, + add_end_docstrings, + is_tf_available, + is_tokenizers_available, + is_torch_available, + logging, +) from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline @@ -18,6 +25,9 @@ from ..modeling_tf_utils import TFPreTrainedModel from ..modeling_utils import PreTrainedModel + if is_tokenizers_available(): + import tokenizers + if is_tf_available(): import tensorflow as tf @@ -180,6 +190,7 @@ def _sanitize_parameters( max_seq_len=None, max_question_len=None, handle_impossible_answer=None, + align_to_words=None, **kwargs ): # Set defaults values @@ -208,6 +219,8 @@ def _sanitize_parameters( postprocess_params["max_answer_len"] = max_answer_len if handle_impossible_answer is not None: postprocess_params["handle_impossible_answer"] = handle_impossible_answer + if align_to_words is not None: + postprocess_params["align_to_words"] = align_to_words return preprocess_params, {}, postprocess_params def __call__(self, *args, **kwargs): @@ -243,6 +256,9 @@ def __call__(self, *args, **kwargs): The maximum length of the question after tokenization. It will be truncated if needed. handle_impossible_answer (`bool`, *optional*, defaults to `False`): Whether or not we accept impossible as an answer. + align_to_words (`bool`, *optional*, defaults to `True`): + Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on + non-space-separated languages (like Japanese or Chinese) Return: A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys: @@ -386,6 +402,7 @@ def postprocess( top_k=1, handle_impossible_answer=False, max_answer_len=15, + align_to_words=True, ): min_null_score = 1000000 # large and positive answers = [] @@ -464,15 +481,8 @@ def postprocess( for s, e, score in zip(starts, ends, scores): s = s - offset e = e - offset - try: - start_word = enc.token_to_word(s) - end_word = enc.token_to_word(e) - start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0] - end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1] - except Exception: - # Some tokenizers don't really handle words. Keep to offsets then. - start_index = enc.offsets[s][0] - end_index = enc.offsets[e][1] + + start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words) answers.append( { @@ -490,6 +500,24 @@ def postprocess( return answers[0] return answers + def get_indices( + self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool + ) -> Tuple[int, int]: + if align_to_words: + try: + start_word = enc.token_to_word(s) + end_word = enc.token_to_word(e) + start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0] + end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1] + except Exception: + # Some tokenizers don't really handle words. Keep to offsets then. + start_index = enc.offsets[s][0] + end_index = enc.offsets[e][1] + else: + start_index = enc.offsets[s][0] + end_index = enc.offsets[e][1] + return start_index, end_index + def decode( self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray ) -> Tuple: diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index c3a0da2f2b5e9a..001254aa94b01e 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -171,6 +171,29 @@ def ensure_large_logits_postprocess( self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"}) + @slow + @require_torch + def test_small_model_japanese(self): + question_answerer = pipeline( + "question-answering", + model="KoichiYasuoka/deberta-base-japanese-aozora-ud-head", + ) + output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている") + + # Wrong answer, the whole text is identified as one "word" since the tokenizer does not include + # a pretokenizer + self.assertEqual( + nested_simplify(output), + {"score": 1.0, "start": 0, "end": 30, "answer": "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"}, + ) + + # Disable word alignment + output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False) + self.assertEqual( + nested_simplify(output), + {"score": 1.0, "start": 15, "end": 18, "answer": "教科書"}, + ) + @slow @require_torch def test_small_model_long_context_cls_slow(self):