Generate: basic token streaming #22449

Merged: 8 commits, Mar 30, 2023
Changes from 1 commit
Add documentation; More robust word-by-word printing in TextStreamer; Doctests
gante committed Mar 30, 2023
commit 8d481e824c3161a28f7397a7b116d5ca197ee82f
23 changes: 23 additions & 0 deletions docs/source/en/generation_strategies.mdx
@@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
['Les fichiers de configuration sont faciles à utiliser !']
```

## Streaming

`generate()` supports streaming through its `streamer` input. The `streamer` input is compatible with any instance
of a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens and
`end()` is used to flag the end of text generation.
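
For illustration, here is a minimal sketch of a custom streamer that follows this interface and collects the generated
token ids into a Python list instead of printing them (the `ListStreamer` class below is hypothetical, not part of the
library):

```python
import torch

from transformers.generation.streamers import BaseStreamer


class ListStreamer(BaseStreamer):
    """Hypothetical streamer that accumulates generated token ids in a list."""

    def __init__(self):
        self.tokens = []

    def put(self, value: torch.Tensor):
        # The prompt arrives as a (1, seq_len) tensor, later tokens as 1D tensors.
        if value.dim() > 1:
            value = value[0]
        self.tokens.extend(value.tolist())

    def end(self):
        # Nothing to flush; generation has finished.
        pass
```

Passing `streamer=ListStreamer()` to `generate()` would then fill `tokens` as generation progresses.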

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` to
your screen, one word at a time:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```

## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
52 changes: 42 additions & 10 deletions src/transformers/generation/streamers.py
@@ -36,7 +36,7 @@ def end(self):

class TextStreamer(BaseStreamer):
"""
Simple text streamer that prints a token as soon as it gets them.
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

Parameters:
tokenizer (`AutoTokenizer`):
@@ -47,26 +47,58 @@ class TextStreamer(BaseStreamer):
    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

    >>> tok = AutoTokenizer.from_pretrained("distilgpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
    >>> inputs = tok(["This cat is"], return_tensors="pt")
    >>> tok = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
    >>> streamer = TextStreamer(tok)
    >>> model.generate(**inputs, streamer=streamer)

    >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
    >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
    An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
    ```
    """

    def __init__(self, tokenizer: "AutoTokenizer"):
        self.tokenizer = tokenizer
        self.token_cache = []
        self.print_len = 0

    def put(self, value):
        """Prints the token(s) to stdout"""
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        text = self.tokenizer.decode(value)
        print(text, flush=True, end="")

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, print up to the last space char (a simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        print(printable_text, flush=True, end="")

    def end(self):
        """Prints a newline to stdout"""
        print("", flush=True)
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Print a newline (and the remaining text, if any)
        print(printable_text, flush=True)
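
As an aside on the `put()` logic above: the token cache and `print_len` offset exist because text decoded from a
partial token sequence can still change or grow once the next token arrives, so the streamer only releases text up to
the last space (or flushes everything at a newline). The following self-contained sketch, assuming the same `gpt2`
tokenizer as the docs example, replays that flush-on-newline / hold-back-the-last-word heuristic outside the class:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
token_ids = tok("An increasing sequence: one, two, three,\n").input_ids

token_cache, print_len = [], 0
for token_id in token_ids:  # stand-in for successive put() calls during generation
    token_cache.append(token_id)
    text = tok.decode(token_cache)
    if text.endswith("\n"):
        # A newline is a safe flush point: emit everything and reset the cache.
        printable_text = text[print_len:]
        token_cache, print_len = [], 0
    else:
        # Hold back the trailing, possibly incomplete word until the next space arrives.
        printable_text = text[print_len : text.rfind(" ") + 1]
        print_len += len(printable_text)
    print(printable_text, flush=True, end="")

# TextStreamer.end() performs the final flush of anything still left in the cache.
```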