Add vLLM TPU example and update TPU readme
Change optimum-tpu fork from bihan to dstack repo
Bihan Rana authored and Bihan Rana committed Sep 1, 2024
1 parent 44c0da0 commit 96ad111
Showing 12 changed files with 212 additions and 50 deletions.
11 changes: 11 additions & 0 deletions docs/overrides/examples.html
@@ -81,6 +81,17 @@ <h3>
Learn how to deploy and fine-tune LLMs on AMD.
</p>
</a>

<a href="/examples/accelerators/tpu"
class="feature-cell sky">
<h3>
TPU
</h3>

<p>
Learn how to deploy and fine-tune LLMs on TPU.
</p>
</a>
</div>

<div class="tx-landing__features_text">
10 changes: 10 additions & 0 deletions docs/overrides/home.html
@@ -336,6 +336,16 @@ <h3>
</p>
</a>

<a href="/examples/accelerators/tpu" class="feature-cell sky">
<h3>
TPU
</h3>

<p>
Learn how to deploy and fine-tune LLMs on TPU.
</p>
</a>

<a href="/examples/llms/llama31" class="feature-cell sky">
<h3>
Llama 3.1
164 changes: 126 additions & 38 deletions examples/accelerators/tpu/README.md
@@ -11,46 +11,114 @@ Below are a few examples of using TPUs for deployment and fine-tuning.
## Deployment

### Running as a service
You can use any serving framework, such as vLLM or TGI. Here's an example of a [service](https://dstack.ai/docs/services) that deploys
Llama 3.1 8B using [vLLM :material-arrow-top-right-thin:{ .external }](https://github.com/vllm-project/vllm){:target="_blank"} or
[Optimum TPU :material-arrow-top-right-thin:{ .external }](https://github.com/huggingface/optimum-tpu){:target="_blank"}.

=== "vLLM"

<div editor-title="examples/deployment/vllm/service-tpu.dstack.yml">

```yaml
type: service
# The name is optional; if not specified, it's generated randomly
name: llama31-service-vLLM

env:
- MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
- HUGGING_FACE_HUB_TOKEN
- DATE=20240828
- TORCH_VERSION=2.5.0
- VLLM_TARGET_DEVICE=tpu
- MAX_MODEL_LEN=4096

commands:
- pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
- pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
- pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
- git clone https://github.com/vllm-project/vllm.git
- cd vllm
- pip install -r requirements-tpu.txt
- apt-get install -y libopenblas-base libopenmpi-dev libomp-dev
- python setup.py develop
- vllm serve $MODEL_ID
  --tensor-parallel-size 8
  --max-model-len $MAX_MODEL_LEN
  --port 8000

# Expose the vllm server port
port:
- 8000

spot_policy: auto

resources:
  gpu: v5litepod-8

# (Optional) Enable the OpenAI-compatible endpoint
model:
  format: openai
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct
```
</div>

=== "Optimum TPU"

<div editor-title="examples/deployment/optimum-tpu/service.dstack.yml">

```yaml
type: service
name: llama31-service-optimum-tpu

# Using a custom Docker image; pending on https://github.com/huggingface/optimum-tpu/pull/87
image: sjbbihan/optimum-tpu:latest
env:
- HUGGING_FACE_HUB_TOKEN
- MODEL_ID=meta-llama/Meta-Llama-3.1-8B
- MAX_TOTAL_TOKENS=4096
- MAX_BATCH_PREFILL_TOKENS=4095
commands:
- text-generation-launcher --port 8000
port: 8000

resources:
  gpu: v5litepod-8

spot_policy: auto

model:
  format: tgi
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B
```
</div>

Note that for `Optimum TPU`, `MAX_INPUT_TOKENS` defaults to 4095; consequently, `MAX_BATCH_PREFILL_TOKENS` must be set to 4095.

??? info "Docker image"
    The official Docker image `huggingface/optimum-tpu:latest` doesn’t support Llama 3.1 8B.
    We’ve created a custom image with the fix: `sjbbihan/optimum-tpu:latest`.
    Once the [pull request :material-arrow-top-right-thin:{ .external }](https://github.com/huggingface/optimum-tpu/pull/87){:target="_blank"} is merged,
    the official Docker image can be used.
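
Once the service is up, it can be queried through its OpenAI-compatible endpoint. Below is a minimal sketch using the `openai` Python client against the vLLM service above; the base URL and token are placeholders (the exact endpoint depends on how the service is published, e.g. through a dstack gateway), so substitute the values reported for your own deployment.

```python
# A minimal sketch of querying the deployed model via the OpenAI-compatible
# endpoint. The base URL and token below are placeholders, not real values;
# use the endpoint and token that dstack reports for your deployment.
from openai import OpenAI

client = OpenAI(
    base_url="https://gateway.example.com",  # hypothetical gateway URL
    api_key="<dstack token>",                # placeholder access token
)

completion = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is a TPU v5e pod slice?"}],
    max_tokens=128,
)
print(completion.choices[0].message.content)
```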

### Memory requirements

Below are the approximate memory requirements for serving LLMs with their corresponding TPUs.

| Model size | bfloat16 | TPU          |
|------------|----------|--------------|
| **8B**     | 16GB     | v5litepod-8  |
| **70B**    | 140GB    | v5litepod-16 |
| **405B**   | 810GB    | v5litepod-64 |

Note, TPU v5litepod is optimized for serving transformer-based models. Each core within the v5litepod is equipped with 16GB of memory.
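
As a rough sanity check on these figures, bfloat16 stores weights at about 2 bytes per parameter, so weight memory alone is roughly 2 GB per billion parameters; KV cache and activations add further overhead at serving time. A minimal sketch reproducing the table's numbers:

```python
# Back-of-the-envelope check of the serving table: bfloat16 weights take
# ~2 bytes per parameter, i.e. roughly 2 GB per billion parameters.
# KV cache and activations add real overhead on top of this estimate.
for params_billion in (8, 70, 405):
    weight_gb = params_billion * 2
    print(f"{params_billion}B parameters -> ~{weight_gb} GB of weights in bfloat16")
```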

### Supported frameworks

| Framework | Quantization   | Note                                            |
|-----------|----------------|-------------------------------------------------|
| **TGI**   | bfloat16       | To deploy with TGI, Optimum TPU is recommended. |
| **vLLM**  | int8, bfloat16 |                                                 |

### Running a configuration

@@ -74,14 +74,15 @@ python: "3.11"
env:
- HUGGING_FACE_HUB_TOKEN
commands:
- git clone https://github.com/Bihan/optimum-tpu.git
- git clone -b add_llama_31_support https://github.com/dstackai/optimum-tpu.git
- mkdir -p optimum-tpu/examples/custom/
- cp examples/fine-tuning/optimum-tpu/llama31/train.py optimum-tpu/examples/custom/train.py
- cp examples/fine-tuning/optimum-tpu/llama31/config.yaml optimum-tpu/examples/custom/config.yaml
- cd optimum-tpu
- pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
- pip install datasets evaluate
- pip install accelerate -U
- pip install peft
- python examples/custom/train.py examples/custom/config.yaml
ports:
- 6006
@@ -95,6 +164,24 @@ resources:
[//]: # (### Fine-Tuning with TRL)
[//]: # (Use the example `examples/fine-tuning/optimum-tpu/gemma/train.dstack.yml` to Finetune `Gemma-2B` model using `trl` with `dstack` and `optimum-tpu`. )

### Memory requirements

Below are the approximate memory requirements for fine-tuning LLMs with their corresponding TPUs.

| Model size | LoRA  | TPU          |
|------------|-------|--------------|
| **8B**     | 16GB  | v5litepod-8  |
| **70B**    | 160GB | v5litepod-16 |
| **405B**   | 950GB | v5litepod-64 |

Note, TPU v5litepod is optimized for fine-tuning transformer-based models. Each core within the v5litepod is equipped with 16GB of memory.
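
For intuition on why LoRA fine-tuning stays close to the serving footprint: the base weights remain frozen in bfloat16, and each adapted weight matrix only gains `r * (d_in + d_out)` trainable parameters. The sketch below is illustrative only; the projection shape and target modules are assumptions rather than the exact Llama 3.1 configuration, though `r=4` matches `lora_r` in `config.yaml`.

```python
# Illustrative only: LoRA-trainable parameters for one adapted projection,
# versus its frozen base weights. The 4096x4096 shape and r=4 are assumptions.
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    # LoRA adds A (r x d_in) and B (d_out x r) beside the frozen W (d_out x d_in)
    return r * (d_in + d_out)

frozen = 4096 * 4096                                 # ~16.8M frozen parameters
trainable = lora_trainable_params(4096, 4096, r=4)   # 32,768 trainable parameters
print(f"trainable: {trainable}, frozen: {frozen}, ratio: {trainable / frozen:.4%}")
```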

### Supported frameworks

| Framework       | Quantization | Note                                                                                |
|-----------------|--------------|-------------------------------------------------------------------------------------|
| **TRL**         | bfloat16     | To fine-tune using TRL, Optimum TPU is recommended. Llama 3.1 is not yet supported. |
| **PyTorch XLA** | bfloat16     |                                                                                     |

## Dev environments

Before running a task or service, it's recommended that you first start with
@@ -115,7 +202,8 @@ or send a [pull request :material-arrow-top-right-thin:{ .external }](https://gi

## What's next?

1. Browse [Optimum TPU :material-arrow-top-right-thin:{ .external }](https://github.com/huggingface/optimum-tpu),
   [Optimum TPU TGI :material-arrow-top-right-thin:{ .external }](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference) and
   [vLLM :material-arrow-top-right-thin:{ .external }](https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html).
2. Check [dev environments](https://dstack.ai/docs/dev-environments), [tasks](https://dstack.ai/docs/tasks),
[services](https://dstack.ai/docs/services), and [fleets](https://dstack.ai/docs/fleets).
2 changes: 1 addition & 1 deletion examples/deployment/optimum-tpu/.dstack.yml
@@ -3,7 +3,7 @@ type: dev-environment
name: vscode-optimum-tpu

# Using a Docker image with a fix instead of the official one
# More details at https://github.com/huggingface/optimum-tpu/pull/85
# More details at https://github.com/huggingface/optimum-tpu/pull/87
image: sjbbihan/optimum-tpu:latest
# Required environment variables
env:
4 changes: 2 additions & 2 deletions examples/deployment/optimum-tpu/service.dstack.yml
@@ -3,12 +3,12 @@ type: service
name: llama31-service-optimum-tpu

# Using a Docker image with a fix instead of the official one
# More details at https://github.com/huggingface/optimum-tpu/pull/85
# More details at https://github.com/huggingface/optimum-tpu/pull/87
image: sjbbihan/optimum-tpu:latest
# Required environment variables
env:
- HUGGING_FACE_HUB_TOKEN
- MODEL_ID=meta-llama/Meta-Llama-3.1-8B
- MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
- MAX_CONCURRENT_REQUESTS=4
- MAX_INPUT_TOKENS=128
- MAX_TOTAL_TOKENS=150
2 changes: 1 addition & 1 deletion examples/deployment/optimum-tpu/task.dstack.yml
@@ -3,7 +3,7 @@ type: task
name: llama31-task-optimum-tpu

# Using a Docker image with a fix instead of the official one
# More details at https://github.com/huggingface/optimum-tpu/pull/85
# More details at https://github.com/huggingface/optimum-tpu/pull/87
image: sjbbihan/optimum-tpu:latest
# Required environment variables
env:
41 changes: 41 additions & 0 deletions examples/deployment/vllm/service-tpu.dstack.yml
@@ -0,0 +1,41 @@
type: service
# The name is optional; if not specified, it's generated randomly
name: llama31-service-vLLM

env:
- MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
- HUGGING_FACE_HUB_TOKEN
- DATE=20240828
- TORCH_VERSION=2.5.0
- VLLM_TARGET_DEVICE=tpu
- MAX_MODEL_LEN=4096

commands:
- pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
- pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
- pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
- git clone https://github.com/vllm-project/vllm.git
- cd vllm
- pip install -r requirements-tpu.txt
- apt-get install -y libopenblas-base libopenmpi-dev libomp-dev
- python setup.py develop
- vllm serve $MODEL_ID
  --tensor-parallel-size 8
  --max-model-len $MAX_MODEL_LEN
  --port 8000

# Expose the vllm server port
port:
- 8000

spot_policy: auto

resources:
  gpu: v5litepod-8

# (Optional) Enable the OpenAI-compatible endpoint
model:
  format: openai
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct
3 changes: 1 addition & 2 deletions examples/fine-tuning/optimum-tpu/gemma/train.dstack.yml
@@ -17,8 +17,7 @@ commands:
- pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
- pip install trl peft
- python examples/custom/train.py examples/custom/config.yaml
ports:
- 6006


resources:
  gpu: v5litepod-8
3 changes: 2 additions & 1 deletion examples/fine-tuning/optimum-tpu/llama31/.dstack.yml
@@ -13,14 +13,15 @@ env:
# Refer to Note section in examples/gpus/tpu/README.md for more information about the optimum-tpu repository.
# Uncomment if you want the environment to be pre-installed
#init:
# - git clone https://github.com/Bihan/optimum-tpu.git
# - git clone -b add_llama_31_support https://github.com/dstackai/optimum-tpu.git
# - mkdir -p optimum-tpu/examples/custom/
# - cp examples/fine-tuning/optimum-tpu/llama31/train.py optimum-tpu/examples/custom/train.py
# - cp examples/fine-tuning/optimum-tpu/llama31/config.yaml optimum-tpu/examples/custom/config.yaml
# - cd optimum-tpu
# - pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
# - pip install datasets evaluate
# - pip install accelerate -U
# - pip install peft

ide: vscode

1 change: 1 addition & 0 deletions examples/fine-tuning/optimum-tpu/llama31/config.yaml
@@ -6,4 +6,5 @@ output_dir: "./finetuned_models/llama3_fine_tuned"
optim: "adafactor"
dataset_name: "Abirate/english_quotes"
model_name: "meta-llama/Meta-Llama-3.1-8B"
lora_r: 4
push_to_hub: True
5 changes: 2 additions & 3 deletions examples/fine-tuning/optimum-tpu/llama31/train.dstack.yml
@@ -10,17 +10,16 @@ env:

# Commands of the task
commands:
- git clone https://github.com/Bihan/optimum-tpu.git
- git clone -b add_llama_31_support https://github.com/dstackai/optimum-tpu.git
- mkdir -p optimum-tpu/examples/custom/
- cp examples/fine-tuning/optimum-tpu/llama31/train.py optimum-tpu/examples/custom/train.py
- cp examples/fine-tuning/optimum-tpu/llama31/config.yaml optimum-tpu/examples/custom/config.yaml
- cd optimum-tpu
- pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
- pip install datasets evaluate
- pip install accelerate -U
- pip install peft
- python examples/custom/train.py examples/custom/config.yaml
ports:
- 6006

resources:
  gpu: v5litepod-8