[RFC] Initial Support for Cloud TPUs #3620

Open
4 of 6 tasks
WoosukKwon opened this issue Mar 25, 2024 · 14 comments

Labels
RFC tpu Related to Google TPUs

Comments

@WoosukKwon
Collaborator

WoosukKwon commented Mar 25, 2024

Progress

Project Scope

This project focuses on making vLLM compatible with Google Cloud TPUs. Our goal is seamless integration so that users can easily run vLLM on TPUs for both online and offline inference. We will target common setups, such as popular models like Gemma running in the bfloat16 data type.

Target TPUs and Models

We will focus on the most recent generations of TPUs, namely TPU v4, v5e, and v5p, given their superior performance over previous generations. We will start by making sure vLLM works with dense models such as Gemma. After that, we will expand support to Mixture-of-Experts (MoE) models such as Mixtral.

Features Not Included (for now)

The following features are outside the scope of this initial project, but we'd like to tackle them in the future:

  • Speculative decoding
  • GPTQ/AWQ Quantization
  • Multi-LoRA serving

Design

Overview


To integrate the TPU backend into vLLM, we will add a new TPU executor and TPU worker, which are the counterparts of the GPU executor and GPU worker, respectively. Unlike NVIDIA and AMD GPUs, which share the same executor and worker, TPUs will get a separate code path because of the significant differences between GPUs and TPUs. The two backends will, however, share the other components of LLMEngine, namely the scheduler, KV cache manager, and tokenizer, since these are (almost) device-agnostic.
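A minimal sketch of this split in plain Python: one device-agnostic scheduler feeds whichever executor matches the hardware, with NVIDIA and AMD mapped to the same GPU executor and TPUs given their own. All class names, the device strings, and the `_EXECUTORS` table are hypothetical stand-ins, not vLLM's actual API.

```python
from typing import List


class Scheduler:
    """Stand-in for the device-agnostic scheduler shared by all backends."""

    def __init__(self, waiting: List[int]):
        self.waiting = list(waiting)

    def schedule(self, max_seqs: int) -> List[int]:
        # Pick the next batch of sequence IDs; knows nothing about the device.
        batch, self.waiting = self.waiting[:max_seqs], self.waiting[max_seqs:]
        return batch


class ExecutorBase:
    backend = "base"

    def execute_model(self, seq_ids: List[int]) -> str:
        # A real executor would dispatch to its workers; here we just report.
        return f"{self.backend}: ran {len(seq_ids)} seqs"


class GPUExecutor(ExecutorBase):
    backend = "gpu"  # one code path for NVIDIA and AMD GPUs


class TPUExecutor(ExecutorBase):
    backend = "tpu"  # separate code path built on PyTorch XLA


# Hypothetical device-string -> executor mapping.
_EXECUTORS = {"cuda": GPUExecutor, "rocm": GPUExecutor, "tpu": TPUExecutor}


class Engine:
    """Stand-in for LLMEngine: shared scheduler, backend-specific executor."""

    def __init__(self, device: str, waiting: List[int]):
        self.scheduler = Scheduler(waiting)
        self.executor = _EXECUTORS[device]()

    def step(self, max_seqs: int = 2) -> str:
        return self.executor.execute_model(self.scheduler.schedule(max_seqs))
```

Only the executor lookup depends on the device; swapping `"tpu"` for `"cuda"` leaves the scheduling logic untouched.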

PyTorch XLA and JAX

As many components of vLLM are device- and runtime-agnostic, it is possible to use JAX for TPU integration. However, for faster initial integration and maximum code reuse, we will start with PyTorch XLA. Adding a JAX backend to vLLM will be interesting future work.

TPU Workers


For tensor-parallel inference, the vLLM TPU executor will spin up multiple TPU workers, one per TPU chip. Specifically, we will use Ray to connect and manage the TPU workers, which may reside in different TPU VMs. Note that we do not plan to support multi-slice inference for now, but we will support multi-host inference within a single TPU pod slice.
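To make the one-worker-per-chip layout concrete, the sketch below maps each tensor-parallel rank to a (host, local chip) pair within a single pod slice, and rejects configurations that would need more chips than the slice has, since multi-slice inference is out of scope. The function name and the chips-per-host value in the example are assumptions for illustration; in the actual design, Ray is responsible for launching and managing these workers.

```python
from typing import List, Tuple


def place_tp_workers(
    tp_size: int, chips_per_host: int, hosts_in_slice: int
) -> List[Tuple[int, int]]:
    """Return (host_index, local_chip_index) for each tensor-parallel rank.

    One worker per TPU chip; hosts (TPU VMs) in the slice are filled in order.
    """
    capacity = chips_per_host * hosts_in_slice
    if tp_size > capacity:
        # The workers would have to span several slices, which is out of scope.
        raise ValueError(f"tp_size={tp_size} exceeds the {capacity} chips in one slice")
    # divmod(rank, chips_per_host) == (rank // chips_per_host, rank % chips_per_host)
    return [divmod(rank, chips_per_host) for rank in range(tp_size)]
```

For example, `place_tp_workers(8, chips_per_host=4, hosts_in_slice=4)` puts ranks 0 to 3 on host 0 and ranks 4 to 7 on host 1, which is the multi-host, single-slice case described above.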

As with the GPU executor, the TPU executor will use Megatron-style model partitioning for tensor-parallel inference. The partitioning strategy will be hardcoded into the model by replacing nn.Linear with RowParallelLinear and ColumnParallelLinear. Auto-sharding the model is left for future work.
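The reason this partitioning needs no communication between the two linear layers of an MLP can be checked numerically. The sketch below uses NumPy (an assumption; the real layers are PyTorch modules) to shard the up-projection by columns and the down-projection by rows across four hypothetical ranks, applies the elementwise activation locally on each rank, and simulates the final all-reduce with a sum. ReLU stands in for the model's actual activation.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


rng = np.random.default_rng(0)
tp_size = 4
x = rng.standard_normal((3, 8))     # [num_tokens, hidden_size]
w1 = rng.standard_normal((8, 16))   # up-projection   [hidden_size, intermediate_size]
w2 = rng.standard_normal((16, 8))   # down-projection [intermediate_size, hidden_size]

# Reference: unsharded MLP.
y_ref = relu(x @ w1) @ w2

# ColumnParallelLinear: each rank owns a slice of w1's output columns,
# so the elementwise activation can be applied locally.
w1_shards = np.split(w1, tp_size, axis=1)
# RowParallelLinear: each rank owns the matching slice of w2's input rows.
w2_shards = np.split(w2, tp_size, axis=0)

partials = [relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
# The all-reduce across chips is simulated here by a plain sum.
y_tp = sum(partials)

# The sharded result matches the unsharded one.
assert np.allclose(y_tp, y_ref)
```

Only one collective (the all-reduce after the row-parallel layer) is needed per MLP, which is why the column-then-row pairing is the standard Megatron choice.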

GPU Executor vs. TPU Executor


For GPUs, vLLM uses both eager mode and CUDA graphs for model execution. Specifically, vLLM uses eager mode for prefills and CUDA graphs for decodes. vLLM currently does not use torch.compile for GPUs, but plans to use it in the future. For TPUs, on the other hand, vLLM will use torch.compile (with the openxla_eval backend) to trace the PyTorch model and lower it into an XLA graph.

While vLLM’s GPU and TPU backends will take separate code paths, they will share the PyTorch model code. Most of the custom ops for GPUs will not be needed for TPUs, since the XLA compiler can generate equivalent code automatically. Therefore, for each target op, vLLM will have two implementations, _forward and _forward_cuda, and select one of them at run time depending on the hardware backend. For example, we can define the target ops/layers as follows:

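import torch
import torch.nn as nn


class Op(nn.Module):

    def __init__(self):
        super().__init__()
        # Illustrative backend check (also true on ROCm builds of PyTorch);
        # the actual selection logic may differ.
        self.use_cuda_ops = torch.cuda.is_available()

    def _forward(self, *args, **kwargs):
        # PyTorch implementation that can be optimized by compilers
        # such as XLA or torch.compile.
        raise NotImplementedError

    def _forward_cuda(self, *args, **kwargs):
        # Implementation using custom ops written in CUDA.
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        # Select the implementation at run time based on the hardware backend.
        if self.use_cuda_ops:
            return self._forward_cuda(*args, **kwargs)
        return self._forward(*args, **kwargs)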

Important exceptions to this are the FlashAttention and PagedAttention custom ops, which cannot be generated by the XLA compiler. We will use custom Pallas kernels for them.
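For reference, the sketch below spells out in NumPy what a PagedAttention op computes for a single decode query: gather the sequence's KV blocks through its block table, drop the unused slots in the last block, and apply scaled dot-product attention. It is not the Pallas kernel, and it ignores multiple heads, GQA, and batching; the function name and argument layout are assumptions chosen for readability.

```python
import numpy as np


def paged_attention_decode(q, k_cache, v_cache, block_table, context_len, block_size):
    """Reference (non-kernel) PagedAttention for one sequence and one head.

    q:                 [head_dim], query of the token being decoded
    k_cache, v_cache:  [num_blocks, block_size, head_dim], physical KV blocks
    block_table:       logical block index -> physical block index
    context_len:       number of valid tokens in this sequence's KV cache
    """
    head_dim = q.shape[-1]
    num_logical_blocks = -(-context_len // block_size)  # ceil division
    phys = np.asarray(block_table[:num_logical_blocks])
    # Gather this sequence's blocks in logical order, flatten, drop padding slots.
    k = k_cache[phys].reshape(-1, head_dim)[:context_len]
    v = v_cache[phys].reshape(-1, head_dim)[:context_len]
    # Numerically stable softmax over scaled dot-product scores.
    scores = k @ q / np.sqrt(head_dim)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v
```

The result equals ordinary attention over the contiguous KV sequence; the block table only changes where the keys and values live in memory, which is what lets the KV cache manager allocate blocks non-contiguously.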

Handling Dynamic Shapes

vLLM’s continuous batching has two phases: prefill and decode. vLLM dynamically switches between the two phases based on its scheduling decisions. The input tensor shape for prefills is [batch_size, prefill_len, hidden_size] while the input tensor shape for decodes is [batch_size, 1, hidden_size] since LLMs decode tokens one by one (here we do not consider special cases such as speculative decoding). In LLM inference, batch_size and prefill_len can change at every step.

To meet XLA’s static-shape requirement, we will bucketize the possible input shapes. For decodes, we will bucketize the batch_size dimension by creating buckets for batch_size=[8, 16, 24, 32, 40, …, 256]. For prefills, to reduce the number of compiled graphs, we will fix batch_size to 1 and bucketize the prefill_len dimension by creating buckets for prefill_len=[8, 16, 32, 64, 128, …, max_model_len]. Given that each prefill input contains enough tokens to utilize the TPU efficiently, fixing batch_size to 1 will not hurt performance much. The specific bucket sizes will be tuned after benchmarking the compilation overhead and end-to-end performance.
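The bucketing scheme above can be sketched as follows. The bucket lists follow the ranges given in this RFC, but the helper names, and the assumption that the last prefill bucket is max_model_len itself even when it is not a power of two, are illustrative; the real bucket sizes are still to be tuned. Each input is padded up to the smallest bucket that fits it, so XLA only ever sees a fixed set of shapes.

```python
import bisect
from typing import List


def decode_batch_buckets(max_batch: int = 256, step: int = 8) -> List[int]:
    # Multiples of 8: [8, 16, 24, ..., 256]
    return list(range(step, max_batch + 1, step))


def prefill_len_buckets(max_model_len: int, start: int = 8) -> List[int]:
    # Powers of two starting at 8, capped at max_model_len (assumed rule).
    buckets = []
    n = start
    while n < max_model_len:
        buckets.append(n)
        n *= 2
    buckets.append(max_model_len)
    return buckets


def pad_to_bucket(size: int, buckets: List[int]) -> int:
    """Return the smallest bucket >= size; the input is padded to this size."""
    i = bisect.bisect_left(buckets, size)
    if i == len(buckets):
        raise ValueError(f"{size} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]
```

For example, with max_model_len=2048 this yields 32 decode buckets and 9 prefill buckets, so at most 41 distinct input shapes need compiling; a decode batch of 13 sequences is padded to 16, and a 300-token prompt to 512.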


WoosukKwon added the RFC and tpu (Related to Google TPUs) labels on Mar 25, 2024
@youkaichao
Member

Nit: image links are broken.

@WoosukKwon
Collaborator Author

@youkaichao Thanks for letting me know! Just fixed it.

@simon-mo
Collaborator

"We will use custom Pallas kernels for them."

Can you elaborate on the custom Pallas kernel for PagedAttention? Are there any links?

@WoosukKwon
Collaborator Author

"Can you elaborate on the custom Pallas kernel for PagedAttention? Are there any links?"

Good question. It's not open-sourced yet, but I was told that it will be released in the JAX repository within a week.

@rkooo567
Collaborator

"The input tensor shape for prefills is [batch_size, prefill_len, hidden_size] while the input tensor shape for decodes is [batch_size, 1, hidden_size]"

Is this still true after we moved to 1D queries? Or does it mean we need to support both 1D and 2D query inputs?

@miladm

miladm commented Mar 26, 2024

"Can you elaborate on the custom Pallas kernel for PagedAttention? Are there any links?"

You can find a sample Pallas kernel implementation for FlashAttention in TorchXLA. A similar mechanism would apply to other kernels.

Also cc @liangfu to review

@WoosukKwon
Collaborator Author

"Is this true after we moved to 1dquery? Or does it mean we need to support both 1d and 2d query inputs?"

@rkooo567 I believe the change won't affect the TPU backend, since GPUs and TPUs share only the scheduler, not the worker or model runner. Also, we will introduce a separate attention backend for TPUs, so the changes in FlashAttentionBackend and XFormersBackend will not affect TPUs either.

@liangfu
Contributor

liangfu commented Apr 2, 2024

Thanks for the proposal, @WoosukKwon. I'm interested in learning a few more details:
1/ What is the proposed KV cache layout?
2/ How are we going to use Ray to connect and manage the workers?
3/ How does Ray recognize the connection pattern between chips and communicate efficiently?
4/ How are we going to use bucketing effectively while retaining a reasonable benefit from dynamic shape support? (To my understanding, excessive bucketing would introduce a massive amount of assembly code.)
5/ How are we going to distribute 8 KV heads (in GQA) among 64 chips? Replicate them 8x?

@miladm

miladm commented Apr 16, 2024

Here is the WIP PR for the PagedAttention kernel on Pallas + TorchXLA: pytorch/xla#6912. We expect it to land pretty soon.

cc @wonjoolee95

@Sea-Snell

Hey! I'm really excited about TPU support for vLLM. I just wanted to check on support for larger multi-host pods, since it looks like only single-worker TPUs are currently supported. Is this on the roadmap?

@yiakwy-xpu-ml-framework-team

@miladm The paged attention kernel will soon be made obsolete by flash attention in both the prefill and decoding stages. In that case, memory block management will be handed back to the memory manager.

PagedAttention is a mistake.

@sparsh35

Are there any initial benchmarks for models like Gemma 2 9B and 27B on TPU v5e or v4? I'm considering switching; Hex LLM, the container from Google, achieves around 4000 tok/s on a TPU v5e-8. @WoosukKwon

@yiakwy-xpu-ml-framework-team

"To meet the XLA’s static shape requirement, we will bucketize the possible input shapes. ...to reduce the number of compiled graphs"

It is not an XLA requirement; it is a hardware requirement. If the hardware allocates memory at compile time, then the IR must specify the shape sizes so that allocation can be optimized. A typical optimization is to use static memory as a memory pool from which memory is allocated "dynamically".

However, if your chip is a GPU, you can allocate memory on demand.

XLA provides bounded shapes for the first case: https://github.com/pytorch/xla/blob/master/docs/dynamic_shape.md

@WoosukKwon

@yiakwy-xpu-ml-framework-team

"Here is the WIP PR for the PagedAttention kernel on Pallas + TorchXLA: pytorch/xla#6912. We expect it to land pretty soon.

cc @wonjoolee95"

Do you have any microbenchmarks on TPU (static compilation with memory optimization) comparing paged attention with decoding using fixed lengths of prefill and decode tokens?

Projects
None yet
Development

No branches or pull requests

9 participants