Generate: basic token streaming #22449
Conversation
I am not a fan of this API at all. Why is there a need for the `TextStreamer` to spawn a new process? The `put` method could directly call the print statement.
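For context, here is a minimal sketch of the simpler design being suggested: a streamer whose `put` prints directly, with no extra process. The class name and decode details are illustrative assumptions, not the merged API.

```python
class PrintStreamer:
    """Illustrative streamer: `put` prints decoded tokens immediately."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def put(self, token_ids):
        # Called with each batch of newly generated tokens; decode and
        # print them right away -- no extra process or context manager.
        text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        print(text, end="", flush=True)
```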
@sgugger revised with the simpler implementation (no context manager or multiprocessing) 🤗
Better this way, thanks!
Just an FYI: I have been doing this using

```python
class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False
```

The callback is then used to create an iterator with the `Iteratorize` class here: https://github.com/oobabooga/text-generation-webui/blob/main/modules/callbacks.py#L42

Usage becomes:

```python
def generate_with_callback(callback=None, **kwargs):
    kwargs['stopping_criteria'].append(Stream(callback_func=callback))
    with torch.no_grad():
        shared.model.generate(**kwargs)

def generate_with_streaming(**kwargs):
    return Iteratorize(generate_with_callback, kwargs, callback=None)

with generate_with_streaming(**generate_params) as generator:
    for output in generator:
        ...
```
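For readers who don't want to follow the link, a rough sketch of what an `Iteratorize`-style wrapper does (names and details here are assumptions; the real implementation in `callbacks.py` also handles early termination): it runs the generation function in a background thread and exposes the callback values as an iterator via a queue.

```python
import queue
import threading

class Iteratorize:
    """Sketch: turn a callback-based function into an iterator."""

    _SENTINEL = object()  # marks the end of generation

    def __init__(self, func, kwargs, callback=None):
        self._func = func      # e.g. generate_with_callback
        self._kwargs = kwargs  # forwarded to func
        self._queue = queue.Queue()

    def __enter__(self):
        def _run():
            try:
                # Every value passed to the callback lands on the queue.
                self._func(callback=self._queue.put, **self._kwargs)
            finally:
                self._queue.put(self._SENTINEL)

        # Generation runs in the background while the caller iterates.
        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._thread.join()

    def __iter__(self):
        while True:
            value = self._queue.get()
            if value is self._SENTINEL:
                break
            yield value
```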
@oobabooga 🧠 That's a smart (and unexpected!) use of the stopping criteria. I'm going to work on a standardized Gradio solution today, and a Queue+iterator was indeed my plan. If you don't mind, I will take inspiration from your code 💛 A question regarding your implementation -- you use a separate thread in the `Iteratorize` class; is there a particular reason for that?
Feel free to copy anything you want.
Honestly, I have no specific reason to give. I just spent several days trying to get the text generation to run in the background, independently of where the output is consumed.
What does this PR do?
Adds token streaming to `.generate()` 🎉

Why now?
I want to showcase and communicate how much faster assisted generation can be... and for that, I need token streaming :D Non-image/video results have a much lower impact.
What's being added
This PR adds a `streamer` input to `generate`. If it is non-`None`, generate will call `streamer.put(new_tokens)` as they are being generated. `streamer` can, therefore, be a wide array of things. This PR adds the simplest case: print tokens as they are generated.

At first, I thought of adding a simpler `stream=True` option. However, the tokenizer would have to be passed into `.generate()`, which we have been avoiding, and it wouldn't be nearly as flexible. I've made the call to make streaming + `.generate()` flexible, and to keep it simple at a `pipeline` level.
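To make the interface concrete, here is a rough usage sketch of the `streamer` argument with the printing streamer this PR adds (model choice and generation settings are arbitrary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example model
model = AutoModelForCausalLM.from_pretrained("gpt2")
streamer = TextStreamer(tokenizer)  # prints tokens as they are generated

inputs = tokenizer("Token streaming is", return_tensors="pt")
# generate() calls streamer.put(new_tokens) at each step, so the text
# shows up incrementally instead of all at once at the end.
model.generate(**inputs, streamer=streamer, max_new_tokens=40)
```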
If this PR gets accepted
The plan is to:

* add a `stream=True` flag in `pipeline` to start

How does it look
Here's an example. Note that it is running on CPU, so we can actually see the streaming effect (a 3090 is too fast 😅). On GPU it also streams, but much faster 🔥
Screen.Recording.2023-03-29.at.16.39.55.mov