Adding a new `return_full_text` parameter to TextGenerationPipeline. (huggingface#9852)

* Adding a new `return_full_text` parameter to TextGenerationPipeline.

For text-generation, the input text is often used as a prompt.
In that context, prefixing `generated_text` with the actual input
forces the caller to take an extra step to strip it back out.

The proposed change adds a new parameter, `return_full_text`
(defaulting to `True` for backward compatibility), that lets the
caller ask for the generated continuation only (see the usage
sketch below).

* Doc quality.
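A minimal usage sketch of the new parameter (the `gpt2` checkpoint is an illustrative stand-in for any causal-LM model the pipeline supports):

from transformers import pipeline

generator = pipeline(task="text-generation", model="gpt2")
prompt = "Once upon a time"

# Default behaviour is unchanged: the prompt is prefixed to the output.
outputs = generator(prompt)
assert outputs[0]["generated_text"].startswith(prompt)

# With return_full_text=False, only the generated continuation comes back.
outputs = generator(prompt, return_full_text=False)
assert not outputs[0]["generated_text"].startswith(prompt)

# The default can also be set once, at pipeline construction time.
generator = pipeline(task="text-generation", model="gpt2", return_full_text=False)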
Narsil committed Jan 29, 2021
1 parent bc109ae commit c2d0ffe
Showing 2 changed files with 33 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/transformers/pipelines/text_generation.py
@@ -44,10 +44,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, return_full_text=True, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.check_model_type(self.ALLOWED_MODELS)
+        self.return_full_text = return_full_text
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
     def _parse_and_tokenize(self, *args, **kwargs):
@@ -65,6 +66,7 @@ def __call__(
         text_inputs,
         return_tensors=False,
         return_text=True,
+        return_full_text=None,
         clean_up_tokenization_spaces=False,
         prefix=None,
         **generate_kwargs
@@ -79,6 +81,9 @@ def __call__(
                 Whether or not to include the tensors of predictions (as token indices) in the outputs.
             return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to include the decoded texts in the outputs.
+            return_full_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                If set to :obj:`False`, only the newly generated text is returned; otherwise the full text
+                (prompt included) is returned. Only meaningful if :obj:`return_text` is set to :obj:`True`.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
@@ -94,14 +99,15 @@ def __call__(
             - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the generated text.
         """
+        prefix = prefix if prefix is not None else self.model.config.prefix
+        return_full_text = return_full_text if return_full_text is not None else self.return_full_text
 
         if isinstance(text_inputs, str):
             text_inputs = [text_inputs]
         results = []
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                prefix = prefix if prefix is not None else self.model.config.prefix
                 if prefix is None and self.model.__class__.__name__ in [
                     "XLNetLMHeadModel",
                     "TransfoXLLMHeadModel",
@@ -168,7 +174,12 @@ def __call__(
                             )
                         )
 
-                    record["generated_text"] = prompt_text + text[prompt_length:]
+                    if return_full_text:
+                        all_text = prompt_text + text[prompt_length:]
+                    else:
+                        all_text = text[prompt_length:]
+
+                    record["generated_text"] = all_text
 
                     result.append(record)
                 results += [result]
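The heart of the change is the last hunk above: the decoded output always contains the prompt, and the new branch decides whether to keep it or slice it off at prompt_length. A standalone sketch of the resolution-and-slice logic with illustrative values (the real pipeline derives prompt_length from the decoded input ids, not from len() as the shortcut here does):

# Pipeline-level default vs. call-time override, mirroring __call__:
pipeline_default = True          # self.return_full_text, set in __init__
call_time_value = False          # value passed to __call__, or None if unspecified
return_full_text = call_time_value if call_time_value is not None else pipeline_default

# Illustrative decoded strings; prompt_length is a hypothetical shortcut.
prompt_text = "Hello world"
text = "Hello world, how are you?"      # full decoded generation, prompt included
prompt_length = len(prompt_text)

if return_full_text:
    all_text = prompt_text + text[prompt_length:]   # prompt + continuation
else:
    all_text = text[prompt_length:]                 # continuation only

assert all_text == ", how are you?"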
19 changes: 19 additions & 0 deletions tests/test_pipelines_text_generation.py
@@ -15,6 +15,7 @@
 import unittest
 
 from transformers import pipeline
+from transformers.testing_utils import require_torch
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin

@@ -41,3 +42,21 @@ def test_simple_generation(self):
         self.assertEqual(type(outputs[0][0]["generated_text"]), str)
         self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
         self.assertEqual(type(outputs[1][0]["generated_text"]), str)
+
+    @require_torch
+    def test_generation_output_style(self):
+        text_generator = pipeline(task="text-generation", model=self.small_models[0])
+        # text-generation is non-deterministic by nature, so we can't fully test the output
+
+        outputs = text_generator("This is a test")
+        self.assertIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=False)
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
+        outputs = text_generator("This is a test")
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=True)
+        self.assertIn("This is a test", outputs[0]["generated_text"])
