Generate: basic token streaming #22449
Merged · 8 commits · Mar 30, 2023

Conversation

@gante gante (Member) commented Mar 29, 2023

What does this PR do?

Adds token streaming to .generate() 🎉

Why now?

I want to showcase and communicate how much faster assisted generation can be... and for that, I need token streaming :D Results that can't be shown as an image/video have a much lower impact.

What's being added

This PR adds a streamer input to generate. If it is not None, generate will call streamer.put(new_tokens) as new tokens are generated, so streamer can be a wide array of things. This PR adds the simplest case: a streamer that prints tokens as they are generated.

At first, I thought of adding a simpler stream=True option. However, the tokenizer would have to be passed into .generate(), which we have been avoiding, and it wouldn't be nearly as flexible. I've made the call to make streaming+.generate() flexible, and to keep it simple at a pipeline level.
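
For context, here's a minimal sketch of how the new argument can be used with the simple print-to-stdout streamer added in this PR (the model name, prompt, and generation settings below are just placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Any causal LM works here; "distilgpt2" is only a small placeholder model.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Token streaming is", return_tensors="pt")
streamer = TextStreamer(tokenizer)

# generate() forwards tokens to streamer.put() as they are produced,
# so the decoded text is printed to stdout while generation is still running.
model.generate(**inputs, streamer=streamer, max_new_tokens=30)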

If this PR gets accepted

The plan is to:

  1. Communicate this feature on Twitter (w/Colab examples)
  2. Add to pipelines, maybe with a simpler stream=True flag to start
  3. Add Gradio examples (and, if needed, a specific streamer class)
  4. Add the beam search case to the streamer classes (beam search is much trickier -- we should only print tokens once all candidate beams agree on them, which means extra logic needs to be added; see the sketch after this list)
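
To illustrate the beam search point in item 4: the agreement check boils down to finding the longest prefix of tokens shared by all live beams and only streaming that prefix, since any later token may still change when a different beam ends up winning. A rough sketch of that check (illustrative only, not part of this PR):

def agreed_prefix(beams: list[list[int]]) -> list[int]:
    """Return the longest prefix of token ids on which all beams agree."""
    prefix = []
    for position in zip(*beams):  # walk position by position across the beams
        if all(token == position[0] for token in position):
            prefix.append(position[0])
        else:
            break
    return prefix

# Two beams that only agree on the first two tokens: only [5, 7] is safe to stream.
print(agreed_prefix([[5, 7, 9, 2], [5, 7, 3, 1]]))  # [5, 7]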

How does it look

Here's an example. Note that it is running on CPU, so we can actually see the streaming effect (a 3090 is too fast 😅). On GPU it also streams, but much faster 🔥

Screen.Recording.2023-03-29.at.16.39.55.mov

@gante gante requested a review from sgugger March 29, 2023 15:55
@sgugger sgugger (Collaborator) left a comment

I am not a fan of this API at all. Why is there a need for the TextStreamer to spawn a new process? The put method could directly call the print statement.

Review comments were left on src/transformers/generation/utils.py, src/transformers/generation/streamers.py, and src/transformers/__init__.py (all outdated and resolved).
@HuggingFaceDocBuilderDev commented Mar 29, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante (Member, Author) commented Mar 29, 2023

@sgugger revised with the simpler implementation (no context manager or multiprocessing) 🤗

@sgugger sgugger (Collaborator) left a comment

Better this way, thanks!

@oobabooga (Contributor) commented:

Just an FYI: I have been doing this using transformers.StoppingCriteria to create a callback:

import transformers

class Stream(transformers.StoppingCriteria):
    """A "stopping" criterion that never stops generation; it only fires a callback on each step."""

    def __init__(self, callback_func=None):
        super().__init__()
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Called once per generation step with everything generated so far.
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False  # never request stopping from here

The callback is then used to create an iterator with the Iteratorize class here: https://github.com/oobabooga/text-generation-webui/blob/main/modules/callbacks.py#L42

Usage becomes:

import torch

# `shared`, `Iteratorize`, and `generate_params` come from the text-generation-webui
# code base linked above; they are not part of transformers.

def generate_with_callback(callback=None, **kwargs):
    # Attach the Stream criterion so `callback` fires on every generation step.
    kwargs['stopping_criteria'].append(Stream(callback_func=callback))
    with torch.no_grad():
        shared.model.generate(**kwargs)

def generate_with_streaming(**kwargs):
    return Iteratorize(generate_with_callback, kwargs, callback=None)

with generate_with_streaming(**generate_params) as generator:
    for output in generator:
        ...  # consume the partially generated sequence here
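
For readers who don't follow the link: the idea behind Iteratorize is essentially to run generation in a background thread and push each callback payload through a queue that the for loop consumes. A simplified sketch of that pattern (illustrative, not the actual class from the linked repo; no error handling or early-exit support):

import queue
import threading

class Iteratorize:
    """Run func(callback=..., **kwargs) in a background thread and iterate
    over everything the callback receives."""

    _DONE = object()  # sentinel marking the end of generation

    def __init__(self, func, kwargs, callback=None):
        # The extra `callback` argument of the original class is ignored in this sketch.
        self.q = queue.Queue()

        def task():
            func(callback=self.q.put, **kwargs)
            self.q.put(self._DONE)

        self.thread = threading.Thread(target=task, daemon=True)
        self.thread.start()

    def __iter__(self):
        while (item := self.q.get()) is not self._DONE:
            yield item

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.thread.join()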

@gante gante (Member, Author) commented Mar 31, 2023

@oobabooga 🧠 That's a smart (and unexpected!) use of the stopping criteria.

I'm going to work on a standardized Gradio solution today, and a Queue+iterator was indeed my plan. If you don't mind, I will take inspiration from your code 💛

A question regarding your implementation -- you use a separate thread in Iteratorize, not a separate process. Any reason for picking a thread over a process? (Without running the code, I'd argue in favor of a separate thread for GIL purposes)

@oobabooga oobabooga (Contributor) commented Mar 31, 2023

If you don't mind, I will take inspiration from your code

Feel free to copy anything you want.

Any reason for picking a thread over a process?

Honestly, I have no specific reason to give. I just spent several days trying to get the text generation to run in the background, independently of where the consuming for loop was in the queue, and this is what ended up working. With this, I get nearly as many tokens/s with streaming as without.

raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023