Skip to content

Commit

Permalink
Adding a new align_to_words param to qa pipeline. (huggingface#18010)
Browse files Browse the repository at this point in the history
* Adding a new `align_to_words` param to qa pipeline.

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Import protection.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
2 people authored and amyeroberts committed Aug 17, 2022
1 parent cafb76e commit a25b1b3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
48 changes: 38 additions & 10 deletions src/transformers/pipelines/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import PaddingStrategy, add_end_docstrings, is_tf_available, is_torch_available, logging
from ..utils import (
PaddingStrategy,
add_end_docstrings,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline


Expand All @@ -18,6 +25,9 @@
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel

if is_tokenizers_available():
import tokenizers

if is_tf_available():
import tensorflow as tf

Expand Down Expand Up @@ -180,6 +190,7 @@ def _sanitize_parameters(
max_seq_len=None,
max_question_len=None,
handle_impossible_answer=None,
align_to_words=None,
**kwargs
):
# Set defaults values
Expand Down Expand Up @@ -208,6 +219,8 @@ def _sanitize_parameters(
postprocess_params["max_answer_len"] = max_answer_len
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
if align_to_words is not None:
postprocess_params["align_to_words"] = align_to_words
return preprocess_params, {}, postprocess_params

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -243,6 +256,9 @@ def __call__(self, *args, **kwargs):
The maximum length of the question after tokenization. It will be truncated if needed.
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
Whether or not we accept impossible as an answer.
align_to_words (`bool`, *optional*, defaults to `True`):
Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on
non-space-separated languages (like Japanese or Chinese)
Return:
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
Expand Down Expand Up @@ -386,6 +402,7 @@ def postprocess(
top_k=1,
handle_impossible_answer=False,
max_answer_len=15,
align_to_words=True,
):
min_null_score = 1000000 # large and positive
answers = []
Expand Down Expand Up @@ -464,15 +481,8 @@ def postprocess(
for s, e, score in zip(starts, ends, scores):
s = s - offset
e = e - offset
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]

start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)

answers.append(
{
Expand All @@ -490,6 +500,24 @@ def postprocess(
return answers[0]
return answers

def get_indices(
self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
) -> Tuple[int, int]:
if align_to_words:
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
else:
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
return start_index, end_index

def decode(
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
) -> Tuple:
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,29 @@ def ensure_large_logits_postprocess(

self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})

@slow
@require_torch
def test_small_model_japanese(self):
question_answerer = pipeline(
"question-answering",
model="KoichiYasuoka/deberta-base-japanese-aozora-ud-head",
)
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている")

# Wrong answer, the whole text is identified as one "word" since the tokenizer does not include
# a pretokenizer
self.assertEqual(
nested_simplify(output),
{"score": 1.0, "start": 0, "end": 30, "answer": "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"},
)

# Disable word alignment
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False)
self.assertEqual(
nested_simplify(output),
{"score": 1.0, "start": 15, "end": 18, "answer": "教科書"},
)

@slow
@require_torch
def test_small_model_long_context_cls_slow(self):
Expand Down

0 comments on commit a25b1b3

Please sign in to comment.