[RFC]: Multi-Step Scheduling #6854

Open
Tracked by #6801
SolitaryThinker opened this issue Jul 26, 2024 · 13 comments
@SolitaryThinker (Contributor) commented Jul 26, 2024

Motivation.

TL;DR: There is high CPU overhead associated with each decode batch due to the processing and generation of inputs/outputs. Multi-step decoding amortizes the following overheads over n steps at a time:

  • Transfer of the sampled token from GPU to CPU for de-tokenization and response to client
  • Generation of output for the user: pythonization of tensors into Python objects
  • CPU preparation and generation of next step’s input metadata
  • vLLM scheduler

The result is that the GPU is often idle, waiting for CPU operations (5-13 ms of GPU bubble).

With multi-step, multiple decode passes are performed before a GPU-CPU sync is done to invoke the vLLM scheduler and process the sampled tokens. Currently, the GPU->CPU memory transfer for sampled tokens is also synchronous with each decode step, causing bubbles on the GPU. With multi-step, this transfer can happen in a separate CUDA stream and is essentially free, since the CPU runs ahead of the GPU.

See below for the source of performance improvement.

  • Both screenshots are about 200 ms in duration. The top row shows the CUDA kernels and the bottom row contains the Python trace.
  • Each highlighted red box is about 4 ms of GPU bubble in both images.
  • For the baseline, this overhead is incurred on every decode step.
  • With multi-step-8, the 4 ms is only incurred once every 8 decode iterations.

Torch Profiles

Baseline 8B on 1xH100
[Torch profile screenshot, 2024-07-26]

Multi-Step-8 8B on 1xH100
[Torch profile screenshot, 2024-07-26]

Benchmarks

  • ShareGPT using benchmark_serving.py
  • Infinite request rate
  • 1k requests
  • With CUDA graphs
  • No chunked prefill
  • Input/output lengths are from the ShareGPT dataset

MS = multi-step
MS-8 = 8 multi-steps before calling the vLLM scheduler and process_output

| Single GPU | Baseline (Req/s) | MS-8 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|
| A10G 8B Llama | 5.20 | 5.89 | - |
| H100 8B Llama | 20.66 | 40.06 | 43.31 |
| H100 30B Llama | 9.23 | 13.09 | 13.23 |

Proposed Change.

Extend ExecuteModelRequest (the input to Workers) and RequestOutput/SamplerOutput to include metadata for the multi-step state, and modify the existing ModelRunner to properly handle that state. AsyncLLMEngine/LLMEngine will need to be made aware of multi-step in order to call into the vLLM scheduler after n steps instead of on every decode. The existing PP scheduling will not be changed.
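
For illustration only, this is the kind of metadata that might be carried on those requests/outputs; the field names here are hypothetical, not the final API:

```python
from dataclasses import dataclass

@dataclass
class MultiStepMetadata:
    """Hypothetical metadata carried alongside ExecuteModelRequest / SamplerOutput.

    Field names are illustrative only; the final API may differ.
    """
    num_steps: int        # how many decode steps to run before syncing with the engine
    current_step: int     # which step of the multi-step window we are on
    is_last_step: bool    # on the last step, workers block and return all outputs
```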

High-level Algorithm:

  1. Scheduler
  • We schedule a fixed n steps and allocate additional (lookahead) blocks.
  • Only for decoding, not prefill. Prefill runs the same way as today.
  2. At each worker (see the sketch after this list)
  • We prepare the initial inputs the same way.
  • Run the model.
  • The sampler doesn't synchronize CPU <> GPU; it generates the next token on the GPU only.
  • At each iteration, we broadcast the sampled tokens to all workers.
  • Update the inputs for the next step. We use CUDA kernels for the in-place updates because doing this in Torch is too slow.
  • Asynchronously transfer the sampled tokens to the CPU.
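
To make the worker-side flow concrete, here is a minimal sketch of the loop described above. All of the `model_runner` methods (`sample_on_gpu`, `broadcast_sampled_tokens`, `advance_step`) and `model_input.as_kwargs()` are hypothetical stand-ins, not the actual vLLM interfaces.

```python
# Minimal sketch of the per-worker multi-step loop (hypothetical interfaces).
import torch

def run_multi_step(model_runner, model_input, num_steps: int):
    outputs_on_gpu = []
    for _ in range(num_steps):
        # Forward pass and sampling stay entirely on the GPU:
        # no CPU<->GPU synchronization and no pythonization here.
        hidden_states = model_runner.model(**model_input.as_kwargs())
        sampled_token_ids = model_runner.sample_on_gpu(hidden_states)

        # Record an event so later work (D2H copies, pythonization) can
        # wait on this step's sampler output without a global sync.
        ready = torch.cuda.Event()
        ready.record()
        outputs_on_gpu.append((sampled_token_ids, ready))

        # Broadcast the sampled tokens so every TP/PP rank can advance,
        # then update the next step's input metadata in place on the GPU.
        model_runner.broadcast_sampled_tokens(sampled_token_ids)
        model_input = model_runner.advance_step(model_input, sampled_token_ids)
    return outputs_on_gpu
```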

Details:
Multi-step state that needs to be tracked for each (micro)batch (a sketch of one way to group it follows the list):

  • Current step that the batch is on - remaining lookahead slots available
  • sampled_token_ids - to keep track of sampled tokens still on GPU
  • sampler_output_ready_event - CUDA event to make sure we only pythonize if the GPU sampling is finished
  • CUDA event for any forward passes that have not completed yet
  • Any buffers that might be needed for async in-place update of attention metadata (depends on the backend)
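
One possible grouping of this state, sketched as a dataclass. Field names here are illustrative only; the prototype may organize this differently.

```python
from dataclasses import dataclass, field
from typing import List, Optional

import torch

@dataclass
class MultiStepState:
    """Illustrative per-(micro)batch multi-step state; names are hypothetical."""
    num_steps: int                                    # n steps granted by the scheduler
    current_step: int = 0                             # steps already executed
    sampled_token_ids: List[torch.Tensor] = field(default_factory=list)  # still on GPU
    sampler_output_ready_event: Optional[torch.cuda.Event] = None        # gate for pythonization
    forward_done_events: List[torch.cuda.Event] = field(default_factory=list)
    attn_metadata_buffers: Optional[dict] = None      # backend-specific scratch for in-place updates

    @property
    def remaining_lookahead(self) -> int:
        return self.num_steps - self.current_step
```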

Core changes to Engine (a rough sketch follows this list):

  • Add an attribute to the scheduler config, engine arguments, and CLI to enable the vLLM scheduler to return lookahead slots (previously only used for spec-decode)
  • Skip vLLM scheduler invocation if we have not run out of lookahead slots for a batch of decodes
  • Capture pythonized outputs as they become ready to return to the client.
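
A rough sketch of the engine-side control flow, assuming a `multi_step_state` object like the one above; every name here is hypothetical, not the actual engine API:

```python
# Hypothetical sketch: only call the vLLM scheduler when the current batch
# has exhausted its lookahead slots; otherwise keep stepping the workers
# and return whatever outputs have already been pythonized.
def engine_step(engine):
    state = engine.multi_step_state
    if state is None or state.remaining_lookahead == 0:
        # Normal path: schedule a fresh batch, with lookahead slots allocated.
        seq_group_metadata, scheduler_outputs = engine.scheduler.schedule()
        state = engine.multi_step_state = engine.begin_multi_step(seq_group_metadata)

    outputs = engine.execute_one_step(state)   # workers run one decode step
    state.current_step += 1

    # Stream back only the outputs that are already pythonized; the rest
    # will be returned once their GPU->CPU transfers complete.
    return [o for o in outputs if o.is_pythonized]
```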

Core changes to ModelRunner (a sketch of the D2H handling follows this list):

  • For TP/PP: broadcast the sampled tokens to all other ranks so that each of them can call advance_step
  • Synchronize with the previous forward passes using CUDA events to make sure the CPU does not clobber any GPU tensors currently in use when preparing inputs for the next step
  • Synchronize with the previous forward pass's sampler and start the GPU->CPU transfer in a separate CUDA stream
  • Pythonize any ready GPU tensors if the CPU is running ahead
  • Invoke the correct advance_step for in-place updating of the next step's input metadata
  • Make sure to block for any remaining forward passes or GPU->CPU transfers if out of lookahead slots, so that the Engine can call into the vLLM scheduler
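
A sketch of the D2H handling under these constraints; the side stream, `remaining_lookahead`, and the pinned-memory detail are assumptions, not the actual implementation:

```python
import torch

# Side stream used only for sampled-token GPU->CPU copies, so the copies
# overlap with the next step's forward pass on the default stream.
_copy_stream = torch.cuda.Stream()

def async_copy_sampled_tokens(sampled_token_ids: torch.Tensor,
                              ready_event: torch.cuda.Event) -> torch.Tensor:
    with torch.cuda.stream(_copy_stream):
        _copy_stream.wait_event(ready_event)   # wait for sampling to finish
        # non_blocking copy; destination should be pinned memory for true overlap
        return sampled_token_ids.to("cpu", non_blocking=True)

def block_if_out_of_lookahead(state) -> None:
    # Once lookahead slots are exhausted, block on all outstanding forward
    # passes and copies so the engine can safely call the vLLM scheduler.
    if state.remaining_lookahead == 0:
        _copy_stream.synchronize()
        torch.cuda.current_stream().synchronize()
```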

Prototype:
The current prototype is based on speculative decoding's TP1DraftModelRunner logic. There are numerous additions for PP/TP support. For the prototype we created a non-spec-decode MultiStepModelRunner under workers/. The goal is to generalize this into the existing ModelRunner (removing the need for a new file) before merging.

Reasoning: PP+multi-step
TL;DR: Since the current multi-step loop is inside the ModelRunner/Worker, PP scheduling in the Executor will cause bubbles between each step and will not interleave the steps of Batch 1 (VE1) with Batch 2 (VE2).

Feedback Period.

No response

CC List.

@zhisbug @Yard1 @WoosukKwon @rkooo567 @zhuohan123 @simon-mo @comaniac @megha95 @richardliaw

Any Other Things.

Much thanks to @Yard1 for extensive help with design and implementation!

Syncing with @megha95 on ongoing work to make the output_processor async. She proposed moving the sampler out of the model runner.

@rkooo567 (Collaborator)

The result seems pretty impressive!

@zhuohan123 (Member)

+1 on prioritizing this. Really great result!

@robertgshaw2-neuralmagic (Sponsor Collaborator)

Wow! These results look great. Happy to help out as we can

@megha95 (Contributor) commented Jul 27, 2024

@SolitaryThinker thanks for writing out a clear proposal, results look great! I have some followup questions:

  1. How will streaming work with this? Given that 8 sub-steps happen within one step, will the outputs only be streamed to the user after every 8 decoded tokens?
  2. I expect TTFT for a new request will increase with this design, since new prefill requests will only start every 8 steps.
  3. What if a stop token id is decoded in the 1st sub-step (calling the multi-steps inside one large step "sub-steps")? Does decoding continue even though there's a stop token id?

@SolitaryThinker (Contributor, Author) commented Jul 28, 2024

@megha95 thanks for the great questions!

  1. The outputs can be streamed either as they finish or as a larger chunk, depending on the following: if the CPU is able to run ahead of the GPU, then the pythonization happens asynchronously, essentially for "free", and tokens are available to be streamed individually. However, if the GPU is ahead, then pythonization of the remaining/all steps occurs synchronously after the last forward pass finishes. We want to keep the time the GPU is blocked on the CPU as small as possible, so we only perform pythonization if the CPU is ahead (see the sketch after this list).
    Note that the sampled token ids will be available on the CPU as soon as possible, since we perform the GPU<>CPU transfer in a separate stream. So perhaps the pythonization could also be moved out of the ModelRunner and into a separate (process_output) worker?

  2. Yes, TTFT will increase. However, we plan on adding support for chunked prefill, and perhaps a mechanism to dynamically add/remove SequenceGroups from the multi-step batch, which would address this as well as wasted steps.

  3. Currently decode will continue; however, anything past an EOS token is truncated before returning output.
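
To illustrate point 1, here is a rough sketch of "pythonize only when the CPU is ahead"; the output structure and `pythonize_fn` are hypothetical, not the prototype's actual code:

```python
# Rough sketch (hypothetical structures): each cached step output carries a
# CUDA event; if the event has already completed, building the Python objects
# is essentially free, otherwise we defer so the GPU never waits on the CPU.
def pythonize_ready_outputs(cached_outputs, pythonize_fn):
    streamed = []
    for out in cached_outputs:
        if out.get("pythonized"):
            continue
        if out["ready_event"].query():        # has GPU sampling finished?
            out["python_output"] = pythonize_fn(out["sampled_token_ids_cpu"])
            out["pythonized"] = True
            streamed.append(out["python_output"])
    return streamed                            # these can be streamed to clients now
```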

@rkooo567 (Collaborator)

> What if a stop token id is decoded in the 1st sub-step (calling the multi-steps inside one large step "sub-steps")? Does decoding continue even though there's a stop token id?

I think for this one it is also easy to add different policies. For example, we could return early if the number of EOS tokens > X, or something like that. But we will need more benchmarks for these cases.

@zhisbug (Collaborator) commented Jul 28, 2024

@SolitaryThinker : how's the compatibility with PP+TP going?

@WoosukKwon (Collaborator) commented Jul 29, 2024

@SolitaryThinker Thanks for the great work! I have a few comments:

  1. Inter-token latency is very important for us. I think streaming the output token every step is a must, not optional.
  2. How does this work with other features like prefix caching and chunked prefills? For me, it's a bit hard to imagine.

@SolitaryThinker (Contributor, Author)

@zhisbug PP is partly working now (debugging some hanging issues); going to continue focusing on getting PP+TP working as a priority.

@SolitaryThinker (Contributor, Author)

@WoosukKwon

  1. The PR currently only calls process_model_outputs at the end of the n multi-steps, meaning tokens will be streamed together. @megha95 is working on refactoring output processing and streaming, and we can integrate with that to properly stream tokens as soon as they are decoded and asynchronously transferred to the CPU.

  2. Prefix caching and chunked prefill definitely need to be thought about more and won't be supported initially.

@jon-chuang (Contributor)

I guess this is incompatible with guided decoding (#5423), correct? Since guided decoding needs to see output tokens on every decode step.

@SolitaryThinker (Contributor, Author) commented Aug 7, 2024

@jon-chuang It should be possible to make it compatible. Currently each step's pythonized output is available to the output_processor (detokenization) not immediately after that decode step, but after the next step's decode, since we perform the pythonization after launching the next step in order to keep the GPU as busy as possible. Two things would be needed to make it compatible with guided decoding:

  1. Currently we only call output_processor once at the end of the n multi-steps (passing it n pythonized outputs). I have not tried calling output_processor after each pythonized output becomes available, but it should be possible. If others know more, any insights would be great. @megha95 is working on making output_processor async altogether, which should definitely be able to handle each output as it becomes available.

  2. As for the pythonized output being behind by one step, we can add a flag to force this pythonization to happen synchronously with each step when guided decoding is enabled, in order to conform to any constraints of guided decoding.

@jon-chuang (Contributor) commented Aug 8, 2024

I think that #5423 mentions async prepare and apply as well. I also mentioned this possibility here: #7000 (comment)

It will be good to coordinate efforts with logit_processor API changes.

But I think for the time being you should not wait for these features to land, and simply throw an incompatibility error until such a feature lands.
