[Hardware] Initial TPU integration #5292

Merged
merged 193 commits into from
Jun 12, 2024
Merged
Changes from 1 commit
Commits
Show all changes
193 commits
Select commit Hold shift + click to select a range
52a1e90
Add TPU gemma
WoosukKwon Apr 1, 2024
86f073e
Add reference
WoosukKwon Apr 1, 2024
d148c2e
Add requirements
WoosukKwon Apr 1, 2024
3b8f430
Add is_tpu
WoosukKwon Apr 1, 2024
824521c
Add TPU to DeviceConfig
WoosukKwon Apr 1, 2024
5083aa9
Add TPUExecutor
WoosukKwon Apr 1, 2024
27c592b
Add get_dtype_size
WoosukKwon Apr 1, 2024
4cdb732
Add TPU to setup
WoosukKwon Apr 1, 2024
31d05f7
yapf
WoosukKwon Apr 1, 2024
46b31ed
Fix RoPE output shape
WoosukKwon Apr 1, 2024
02e614d
[WIP] Add Pallas backend
WoosukKwon Apr 1, 2024
38e3d33
Add TPU to device config
WoosukKwon Apr 1, 2024
6894d3e
Add JAX to requirements.txt
WoosukKwon Apr 1, 2024
d899009
[WIP] Add TPU worker
WoosukKwon Apr 1, 2024
60ff6b8
Merge branch 'main' into woosuk-tpu
WoosukKwon Apr 10, 2024
0d6402d
Fix requirements
WoosukKwon Apr 10, 2024
696b653
yapf
WoosukKwon Apr 10, 2024
363e6a9
Fix flashattn
WoosukKwon Apr 10, 2024
d4adf92
Merge branch 'main' into woosuk-tpu
WoosukKwon Apr 16, 2024
c59c1e7
Remove
WoosukKwon Apr 16, 2024
eb0a046
Add JAX requirements
WoosukKwon Apr 16, 2024
6692a30
Minor
WoosukKwon Apr 16, 2024
b3b89cf
Renew TPU executor
WoosukKwon Apr 16, 2024
de82e95
Minor
WoosukKwon Apr 16, 2024
6d62e4c
Add torch to dependencies
WoosukKwon Apr 16, 2024
91b47e3
JAX-based TPU worker
WoosukKwon Apr 16, 2024
cedb670
Add gemma
WoosukKwon Apr 17, 2024
8888d1c
Fix logit indices
WoosukKwon Apr 17, 2024
6661c03
Add paged_attn op
WoosukKwon Apr 17, 2024
b25fcc0
Minor
WoosukKwon Apr 17, 2024
25bbc21
Minor
WoosukKwon Apr 17, 2024
5cb213c
Add flash-attn op
WoosukKwon Apr 17, 2024
e4377dd
Add model runner
WoosukKwon Apr 17, 2024
0fb07c0
Minor
WoosukKwon Apr 17, 2024
4880de3
Add attn_mask
WoosukKwon Apr 17, 2024
756c4e7
Add write_to_cache ops
WoosukKwon Apr 17, 2024
ef762cb
Write kV
WoosukKwon Apr 17, 2024
186c88c
explictly return new_kv_caches
WoosukKwon Apr 17, 2024
7e3a230
Fix paged_attn
WoosukKwon Apr 17, 2024
62b870f
Use FlashAttention kernel
WoosukKwon Apr 17, 2024
743695f
Fix write_to_kv_cache
WoosukKwon Apr 19, 2024
8428430
Minor
WoosukKwon Apr 19, 2024
092e3d6
Remove hardcoded path
WoosukKwon Apr 19, 2024
d5fb1c2
Fix JAX jit OOM
WoosukKwon Apr 24, 2024
620e764
Fix cache write
WoosukKwon Apr 24, 2024
f42b4c2
Include argmax to jit
WoosukKwon Apr 24, 2024
5323969
Increase #blocks
WoosukKwon Apr 24, 2024
e2c7ded
Minor
WoosukKwon Apr 25, 2024
81b8b81
Pad to avoid recompilation
WoosukKwon Apr 25, 2024
98eda57
Add timer
WoosukKwon Apr 25, 2024
b62170e
Fix scheduler
WoosukKwon Apr 25, 2024
fa5bacd
Add warmup
WoosukKwon Apr 25, 2024
028f528
Fix KV cache shape
WoosukKwon Apr 25, 2024
2aa9831
Minor
WoosukKwon Apr 25, 2024
21f35c2
Change version
WoosukKwon Apr 26, 2024
d2c6a32
Fix is_tpu
WoosukKwon Apr 26, 2024
aa09283
Format gemma.py
WoosukKwon Apr 26, 2024
d16a348
Add comment
WoosukKwon Apr 26, 2024
4ea41d0
yapf
WoosukKwon Apr 26, 2024
5ae2f81
Add warmup + formatting
WoosukKwon Apr 26, 2024
d830766
yapf
WoosukKwon Apr 26, 2024
8d072db
yapf
WoosukKwon Apr 26, 2024
85d4488
yapf
WoosukKwon Apr 26, 2024
d1591f0
Add op benchmark scripts
WoosukKwon Apr 26, 2024
b15db23
Add precompilation step
WoosukKwon Apr 26, 2024
57690a9
Fix bucketing
WoosukKwon Apr 26, 2024
707a5f6
Move JAX-smi to worker
WoosukKwon Apr 26, 2024
f6637db
Use persistent cache
WoosukKwon Apr 26, 2024
07be6ed
Improve benchmark
WoosukKwon Apr 26, 2024
278e8a1
Add tpu
WoosukKwon Apr 26, 2024
408ff49
Tune pages_per_compute_block
WoosukKwon Apr 26, 2024
3f6288c
Fix for binary cache
WoosukKwon Apr 26, 2024
98a3df0
Disable memory tracking
WoosukKwon Apr 26, 2024
881b884
Add block size
WoosukKwon Apr 27, 2024
c00ddd6
Add buffer donation to benchmark
WoosukKwon Apr 30, 2024
74702d3
Merge branch 'main' into torch-xla
WoosukKwon Apr 30, 2024
3427a8f
Add benchmark_index_copy
WoosukKwon May 1, 2024
4f9dace
Update
WoosukKwon May 5, 2024
04738c9
Fix
WoosukKwon May 5, 2024
01b6f4a
Minor
WoosukKwon May 6, 2024
7496584
Add Pallas backend
WoosukKwon May 6, 2024
5327bd0
Add torch-xla gemma
WoosukKwon May 6, 2024
52b8eb8
Minor fix
WoosukKwon May 6, 2024
39a900f
TPU worker & model runner
WoosukKwon May 6, 2024
f7df218
yapf
WoosukKwon May 6, 2024
8a3d495
Minor
WoosukKwon May 6, 2024
4aa7e7e
Fix
WoosukKwon May 6, 2024
2889253
Use logits to sample
WoosukKwon May 6, 2024
5047229
Scaling factor
WoosukKwon May 6, 2024
bf8cd8f
Minor
WoosukKwon May 6, 2024
0cefb98
Minor
WoosukKwon May 6, 2024
770c298
Minor
WoosukKwon May 6, 2024
a5a7709
fix
WoosukKwon May 7, 2024
b079c6a
Fix
WoosukKwon May 12, 2024
1509eb8
Add megacore_mode
WoosukKwon May 15, 2024
7747def
Fix
WoosukKwon May 16, 2024
f282252
Fix benchmark
WoosukKwon May 16, 2024
614b1b1
yapf
WoosukKwon May 16, 2024
335222d
yapf
WoosukKwon May 16, 2024
f754b67
Minor
WoosukKwon May 20, 2024
d9a6616
Fix megacore for mqa
WoosukKwon May 20, 2024
a98d618
Fix megacore
WoosukKwon May 29, 2024
841eef2
Minor
WoosukKwon May 29, 2024
d02025d
Fix torch compile error
WoosukKwon May 29, 2024
e0e252b
Add memory profiling
WoosukKwon May 29, 2024
e510c0d
Add CustomOp Interface
WoosukKwon Jun 4, 2024
d9d43a6
Move activation
WoosukKwon Jun 4, 2024
19bff1c
Move layernorm
WoosukKwon Jun 4, 2024
8bff05a
Move RoPE
WoosukKwon Jun 4, 2024
af0d31e
Minor
WoosukKwon Jun 4, 2024
a631e7f
Fix
WoosukKwon Jun 4, 2024
a1486ff
Fix
WoosukKwon Jun 4, 2024
e135eae
Fix
WoosukKwon Jun 4, 2024
31e4930
Merge branch 'main' into dispatcher
WoosukKwon Jun 4, 2024
16bab8e
Revert model changes
WoosukKwon Jun 4, 2024
41b9a2a
move back
WoosukKwon Jun 4, 2024
7986c0f
forward_native
WoosukKwon Jun 4, 2024
24e11d2
revert
WoosukKwon Jun 4, 2024
cdc62a2
Move dispatch to offline
WoosukKwon Jun 4, 2024
d1182e7
Add note
WoosukKwon Jun 4, 2024
97a9949
Add compileable RoPE
WoosukKwon Jun 5, 2024
7ad432c
Merge branch 'main' into torch-xla
WoosukKwon Jun 5, 2024
d25c663
Merge remote-tracking branch 'origin/compilable-rope' into torch-xla
WoosukKwon Jun 5, 2024
7d79210
Remove JAX
WoosukKwon Jun 5, 2024
997f53c
Fix
WoosukKwon Jun 5, 2024
8035092
Minor
WoosukKwon Jun 5, 2024
6346708
Remove code
WoosukKwon Jun 5, 2024
917f815
Works
WoosukKwon Jun 5, 2024
5918000
Fix
WoosukKwon Jun 5, 2024
9a4ad83
mypy
WoosukKwon Jun 5, 2024
2d8a411
yapf
WoosukKwon Jun 5, 2024
972744a
yapf
WoosukKwon Jun 5, 2024
de08c61
yapf
WoosukKwon Jun 5, 2024
136c1c1
mypy
WoosukKwon Jun 5, 2024
22564e0
Move cache size config
WoosukKwon Jun 5, 2024
3d002ff
Support temp
WoosukKwon Jun 5, 2024
99b93c7
Fix
WoosukKwon Jun 5, 2024
5173d4b
is_tpu
WoosukKwon Jun 5, 2024
d8939a3
init dist
WoosukKwon Jun 5, 2024
ca9283a
Model loader & yapf
WoosukKwon Jun 5, 2024
924ae82
Minor
WoosukKwon Jun 5, 2024
7e64a4d
Minor
WoosukKwon Jun 5, 2024
d4c494c
Use vLLM layers for gemma
WoosukKwon Jun 5, 2024
dbe83ac
Fix weight loading error
WoosukKwon Jun 5, 2024
1f28dd3
Bench normal attention
WoosukKwon Jun 5, 2024
57c36fa
yapf
WoosukKwon Jun 5, 2024
94c0f3f
Minor
WoosukKwon Jun 5, 2024
c5f9430
Remove
WoosukKwon Jun 5, 2024
2e8860a
Merge branch 'main' into torch-xla
WoosukKwon Jun 5, 2024
b0b42d8
Remove benchmarking scripts
WoosukKwon Jun 5, 2024
8819df9
Remove TPU models
WoosukKwon Jun 5, 2024
62e323e
Add requirements-tpu.txt
WoosukKwon Jun 5, 2024
6475b54
Compatible with compute_logits
WoosukKwon Jun 5, 2024
c7dc9e5
Minor
WoosukKwon Jun 5, 2024
1330c93
Minor
WoosukKwon Jun 5, 2024
908470c
yapf
WoosukKwon Jun 5, 2024
87e9a71
Fix mor MQA
WoosukKwon Jun 6, 2024
bb7c720
Refactor bucketing
WoosukKwon Jun 6, 2024
b899708
yapf
WoosukKwon Jun 6, 2024
96b20d6
Minor
WoosukKwon Jun 6, 2024
84e4c51
Add padding to t, p
WoosukKwon Jun 6, 2024
dc02c01
Fix GQA
WoosukKwon Jun 6, 2024
6875593
yapf
WoosukKwon Jun 6, 2024
77f80fc
Consider program size
WoosukKwon Jun 6, 2024
ee01196
Add top-p sampling
WoosukKwon Jun 6, 2024
e881c1c
MInor
WoosukKwon Jun 6, 2024
f0d3ac9
Disable top-p sampling
WoosukKwon Jun 6, 2024
0393aee
Fix model loading
WoosukKwon Jun 8, 2024
ae967ab
Fix setup.py
WoosukKwon Jun 8, 2024
11c0fa7
Add tpu to latency
WoosukKwon Jun 8, 2024
10240c8
Remove mark.step
WoosukKwon Jun 8, 2024
36ac127
Refactor RoPE
WoosukKwon Jun 8, 2024
90d1e31
Merge branch 'main' into torch-xla
WoosukKwon Jun 8, 2024
005343a
Revert back
WoosukKwon Jun 9, 2024
2024319
Remove benchmark
WoosukKwon Jun 9, 2024
51b2ac7
Fix
WoosukKwon Jun 9, 2024
05e7261
Add XLA cache env variable
WoosukKwon Jun 9, 2024
1d12943
Add TPU dockerfile
WoosukKwon Jun 9, 2024
89ea3aa
Fix requirements-tpu.txt
WoosukKwon Jun 9, 2024
eaf9352
Add TPU docs
WoosukKwon Jun 9, 2024
089476e
Remove tpu-install.sh
WoosukKwon Jun 9, 2024
fa10ec6
Fix docs
WoosukKwon Jun 9, 2024
3d111f1
Remove TODO
WoosukKwon Jun 9, 2024
0e0de1c
Add NotImplementedError
WoosukKwon Jun 10, 2024
c602d78
Enable top-p sampling
WoosukKwon Jun 11, 2024
c56d6ba
Fix RoPE
WoosukKwon Jun 11, 2024
8820d06
Disable top-p sampling
WoosukKwon Jun 11, 2024
205820d
Remove scheduler hack
WoosukKwon Jun 11, 2024
cb5e4f6
Use enforce-eager to skip warmup
WoosukKwon Jun 11, 2024
4be5a3c
Merge branch 'main' into torch-xla
WoosukKwon Jun 11, 2024
b4aa403
Address comments
WoosukKwon Jun 12, 2024
034b9bd
Fix for v5p
WoosukKwon Jun 12, 2024
f5e1bf5
Add build dependencies
WoosukKwon Jun 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix
WoosukKwon committed May 5, 2024
commit 04738c9d63c1c25f929074ac9f49321e234d2cc6
8 changes: 6 additions & 2 deletions benchmarks/kernels/benchmark_index_copy.py
@@ -26,13 +26,17 @@ def write_to_kv_cache(
 ) -> None:
     torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
+    k_cache = k_cache.flatten(0, 1)
+    key = key.flatten(0, 1)
     k_cache = k_cache.index_copy_(0, slot_mapping, key)
+    v_cache = v_cache.flatten(0, 1)
+    value = value.flatten(0, 1)
     v_cache = v_cache.index_copy_(0, slot_mapping, value)


 def benchmark(num_blocks: int):
-    key = torch.randn(BATCH_SIZE * SEQ_LEN, NUM_KV_HEADS, HEAD_SIZE, device=device, dtype=DTYPE)
-    k_cache = torch.randn(num_blocks * BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE, device=device, dtype=DTYPE)
+    key = torch.randn(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_SIZE, device=device, dtype=DTYPE)
+    k_cache = torch.randn(num_blocks, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE, device=device, dtype=DTYPE)
     value = torch.randn_like(key)
     v_cache = torch.randn_like(k_cache)
     slot_mapping = torch.randint(0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64)
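For context, this commit exercises the KV-cache write pattern used by the TPU backend: the paged cache of shape (num_blocks, block_size, num_kv_heads, head_size) and the new tokens of shape (batch, seq, num_kv_heads, head_size) are each flattened along their first two dimensions, and the tokens are scattered into cache slots with index_copy_ along a slot mapping. The snippet below is a minimal sketch of that pattern in plain PyTorch on CPU, not the benchmark itself: the constant values are illustrative assumptions, slot_mapping is explicitly flattened and chosen without collisions for clarity, and the torch_xla buffer-donation call from the diff is omitted since it only applies on XLA devices.

# Minimal CPU sketch of the KV-cache write pattern shown in the diff above.
# Sizes are made up for illustration; the real benchmark runs on TPU via torch_xla.
import torch

BATCH_SIZE, SEQ_LEN = 2, 16            # illustrative values (assumptions)
NUM_KV_HEADS, HEAD_SIZE = 8, 128
NUM_BLOCKS, BLOCK_SIZE = 64, 16
DTYPE = torch.float32


def write_to_kv_cache(key, value, k_cache, v_cache, slot_mapping):
    # Flatten (num_blocks, block_size, ...) -> (num_slots, ...) and
    # (batch, seq, ...) -> (num_tokens, ...) so each token maps to one cache slot.
    k_cache = k_cache.flatten(0, 1)
    v_cache = v_cache.flatten(0, 1)
    key = key.flatten(0, 1)
    value = value.flatten(0, 1)
    slot_mapping = slot_mapping.flatten()
    # In-place scatter of the new tokens into their assigned slots.
    k_cache.index_copy_(0, slot_mapping, key)
    v_cache.index_copy_(0, slot_mapping, value)


key = torch.randn(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_SIZE, dtype=DTYPE)
value = torch.randn_like(key)
k_cache = torch.randn(NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE, dtype=DTYPE)
v_cache = torch.randn_like(k_cache)
# One cache slot per token; random slots chosen without collisions for this sketch.
slot_mapping = torch.randperm(NUM_BLOCKS * BLOCK_SIZE)[:BATCH_SIZE * SEQ_LEN]
write_to_kv_cache(key, value, k_cache, v_cache, slot_mapping)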