[Hardware] Initial TPU integration #5292
Merged
Changes from 1 commit
Commits (193 total)
All commits are authored by WoosukKwon.
52a1e90 Add TPU gemma
86f073e Add reference
d148c2e Add requirements
3b8f430 Add is_tpu
824521c Add TPU to DeviceConfig
5083aa9 Add TPUExecutor
27c592b Add get_dtype_size
4cdb732 Add TPU to setup
31d05f7 yapf
46b31ed Fix RoPE output shape
02e614d [WIP] Add Pallas backend
38e3d33 Add TPU to device config
6894d3e Add JAX to requirements.txt
d899009 [WIP] Add TPU worker
60ff6b8 Merge branch 'main' into woosuk-tpu
0d6402d Fix requirements
696b653 yapf
363e6a9 Fix flashattn
d4adf92 Merge branch 'main' into woosuk-tpu
c59c1e7 Remove
eb0a046 Add JAX requirements
6692a30 Minor
b3b89cf Renew TPU executor
de82e95 Minor
6d62e4c Add torch to dependencies
91b47e3 JAX-based TPU worker
cedb670 Add gemma
8888d1c Fix logit indices
6661c03 Add paged_attn op
b25fcc0 Minor
25bbc21 Minor
5cb213c Add flash-attn op
e4377dd Add model runner
0fb07c0 Minor
4880de3 Add attn_mask
756c4e7 Add write_to_cache ops
ef762cb Write kV
186c88c explictly return new_kv_caches
7e3a230 Fix paged_attn
62b870f Use FlashAttention kernel
743695f Fix write_to_kv_cache
8428430 Minor
092e3d6 Remove hardcoded path
d5fb1c2 Fix JAX jit OOM
620e764 Fix cache write
f42b4c2 Include argmax to jit
5323969 Increase #blocks
e2c7ded Minor
81b8b81 Pad to avoid recompilation
98eda57 Add timer
b62170e Fix scheduler
fa5bacd Add warmup
028f528 Fix KV cache shape
2aa9831 Minor
21f35c2 Change version
d2c6a32 Fix is_tpu
aa09283 Format gemma.py
d16a348 Add comment
4ea41d0 yapf
5ae2f81 Add warmup + formatting
d830766 yapf
8d072db yapf
85d4488 yapf
d1591f0 Add op benchmark scripts
b15db23 Add precompilation step
57690a9 Fix bucketing
707a5f6 Move JAX-smi to worker
f6637db Use persistent cache
07be6ed Improve benchmark
278e8a1 Add tpu
408ff49 Tune pages_per_compute_block
3f6288c Fix for binary cache
98a3df0 Disable memory tracking
881b884 Add block size
c00ddd6 Add buffer donation to benchmark
74702d3 Merge branch 'main' into torch-xla
3427a8f Add benchmark_index_copy
4f9dace Update
04738c9 Fix
01b6f4a Minor
7496584 Add Pallas backend
5327bd0 Add torch-xla gemma
52b8eb8 Minor fix
39a900f TPU worker & model runner
f7df218 yapf
8a3d495 Minor
4aa7e7e Fix
2889253 Use logits to sample
5047229 Scaling factor
bf8cd8f Minor
0cefb98 Minor
770c298 Minor
a5a7709 fix
b079c6a Fix
1509eb8 Add megacore_mode
7747def Fix
f282252 Fix benchmark
614b1b1 yapf
335222d yapf
f754b67 Minor
d9a6616 Fix megacore for mqa
a98d618 Fix megacore
841eef2 Minor
d02025d Fix torch compile error
e0e252b Add memory profiling
e510c0d Add CustomOp Interface
d9d43a6 Move activation
19bff1c Move layernorm
8bff05a Move RoPE
af0d31e Minor
a631e7f Fix
a1486ff Fix
e135eae Fix
31e4930 Merge branch 'main' into dispatcher
16bab8e Revert model changes
41b9a2a move back
7986c0f forward_native
24e11d2 revert
cdc62a2 Move dispatch to offline
d1182e7 Add note
97a9949 Add compileable RoPE
7ad432c Merge branch 'main' into torch-xla
d25c663 Merge remote-tracking branch 'origin/compilable-rope' into torch-xla
7d79210 Remove JAX
997f53c Fix
8035092 Minor
6346708 Remove code
917f815 Works
5918000 Fix
9a4ad83 mypy
2d8a411 yapf
972744a yapf
de08c61 yapf
136c1c1 mypy
22564e0 Move cache size config
3d002ff Support temp
99b93c7 Fix
5173d4b is_tpu
d8939a3 init dist
ca9283a Model loader & yapf
924ae82 Minor
7e64a4d Minor
d4c494c Use vLLM layers for gemma
dbe83ac Fix weight loading error
1f28dd3 Bench normal attention
57c36fa yapf
94c0f3f Minor
c5f9430 Remove
2e8860a Merge branch 'main' into torch-xla
b0b42d8 Remove benchmarking scripts
8819df9 Remove TPU models
62e323e Add requirements-tpu.txt
6475b54 Compatible with compute_logits
c7dc9e5 Minor
1330c93 Minor
908470c yapf
87e9a71 Fix mor MQA
bb7c720 Refactor bucketing
b899708 yapf
96b20d6 Minor
84e4c51 Add padding to t, p
dc02c01 Fix GQA
6875593 yapf
77f80fc Consider program size
ee01196 Add top-p sampling
e881c1c MInor
f0d3ac9 Disable top-p sampling
0393aee Fix model loading
ae967ab Fix setup.py
11c0fa7 Add tpu to latency
10240c8 Remove mark.step
36ac127 Refactor RoPE
90d1e31 Merge branch 'main' into torch-xla
005343a Revert back
2024319 Remove benchmark
51b2ac7 Fix
05e7261 Add XLA cache env variable
1d12943 Add TPU dockerfile
89ea3aa Fix requirements-tpu.txt
eaf9352 Add TPU docs
089476e Remove tpu-install.sh
fa10ec6 Fix docs
3d111f1 Remove TODO
0e0de1c Add NotImplementedError
c602d78 Enable top-p sampling
c56d6ba Fix RoPE
8820d06 Disable top-p sampling
205820d Remove scheduler hack
cb5e4f6 Use enforce-eager to skip warmup
4be5a3c Merge branch 'main' into torch-xla
b4aa403 Address comments
034b9bd Fix for v5p
f5e1bf5 Add build dependencies
Viewing changes from commit ee01196502436b6984e78c8f40f2a12fea0caac8: Add top-p sampling
What's this function used for?
It's for top-p sampling: we need to mask out the logits of tokens that fall outside the top-p set. While this function works, I found that it slowed down TPU performance by ~10x, so I disabled it for now. I can also delete it.
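
For context, here is a minimal sketch of the kind of top-p (nucleus) logits masking described in the comment above. It is not the function from this PR; the name `mask_logits_top_p` and the tensor shapes are assumptions for illustration only.

```python
import torch


def mask_logits_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of top-p (nucleus) logits masking; not the PR's implementation.

    logits: [batch_size, vocab_size] raw model outputs.
    p:      [batch_size] per-sequence top-p thresholds in (0, 1].
    """
    # Sort logits in descending order and compute the cumulative probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(probs, dim=-1)

    # Drop a token if the mass accumulated *before* it already exceeds p,
    # so the highest-probability token is always kept.
    exceeds_p = (cum_probs - probs) > p.unsqueeze(-1)
    sorted_logits = sorted_logits.masked_fill(exceeds_p, float("-inf"))

    # Scatter the masked logits back into the original vocabulary order.
    return torch.empty_like(logits).scatter_(-1, sorted_indices, sorted_logits)
```

Sampling would then apply softmax to the masked logits and draw from the resulting distribution. The full-vocabulary sort and scatter are the expensive steps in this approach, which is plausibly related to the TPU slowdown mentioned above.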