[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin #7701

Merged: 27 commits, Sep 23, 2024


@LucasWilkinson LucasWilkinson commented Aug 20, 2024

Add Machete kernels as a backend for GPTQMarlinLinearMethod and CompressedTensorsWNA16. As part of adding GPTQMarlinLinearMethod support, support for dynamic group indices (g_idx, colloquially known as act-order) was added by permuting the columns of the activation tensor before calling Machete.

This PR also contains an updated heuristic for Machete that helps improve performance for larger GEMMs.
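To illustrate why permuting the activation columns is enough to handle g_idx, here is a minimal sketch in plain Python. It is not the PR's code: lists stand in for tensors, and `argsort`, `matmul`, and `dequantize` are hypothetical helpers. It assumes the weight rows are reordered with the same permutation, so that the group indices become sorted and each group is contiguous.

```python
# Illustrative sketch only: plain Python lists stand in for tensors, and all
# helper names below are hypothetical (they are not vLLM or Machete APIs).

def argsort(values):
    """Indices that sort `values` in ascending order (stable)."""
    return sorted(range(len(values)), key=values.__getitem__)


def matmul(a, b):
    """Naive (M x K) @ (K x N) matrix product."""
    return [[sum(row[k] * b[k][n] for k in range(len(b)))
             for n in range(len(b[0]))] for row in a]


def dequantize(q, scales, g_idx):
    """Scale row k of the quantized weight by the scales of group g_idx[k]."""
    return [[q[k][n] * scales[g_idx[k]][n] for n in range(len(q[0]))]
            for k in range(len(q))]


# K = 4 input channels, N = 2 output channels, 2 quantization groups.
g_idx = [1, 0, 1, 0]                   # act-order: groups are interleaved
q = [[1, 2], [3, 4], [5, 6], [7, 8]]   # quantized weight, shape (K, N)
scales = [[0.5, 0.25], [2.0, 4.0]]     # one row of scales per group
x = [[1.0, 2.0, 3.0, 4.0]]             # activations, shape (1, K)

# Reference result using the original (unsorted) group indices.
reference = matmul(x, dequantize(q, scales, g_idx))

# Sort channels by group, then apply the same permutation to the
# activation columns and the weight rows.
perm = argsort(g_idx)
x_perm = [[row[k] for k in perm] for row in x]
q_perm = [q[k] for k in perm]
g_sorted = [g_idx[k] for k in perm]    # now contiguous per group
permuted = matmul(x_perm, dequantize(q_perm, scales, g_sorted))

assert permuted == reference
```

Because a matrix product sums over the input channels, reordering those channels consistently on both operands leaves the result unchanged. In a real integration the weight reordering would presumably happen once at load time, leaving only the activation permutation at runtime.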

Benchmarking Results (H100)

Llama 3.1 70b (Tensor Parallelism = 1)

neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16

[Figures for Llama 3.1 70B, TP=1: output throughput, input throughput, median TTFT (ms), median TPOT (ms)]

Llama 3.1 405b (Tensor Parallelism = 4)

neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w4a16

[Figures for Llama 3.1 405B, TP=4: output throughput, input throughput, median TTFT (ms), median TPOT (ms)]


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/machete-end2end branch 3 times, most recently from 1361be1 to d5ee5b8 Compare August 20, 2024 19:16
@LucasWilkinson LucasWilkinson changed the title from "[WIP, Kernel] (2/N) Machete - Integrate into GPTQMarlinLinearMethod and CompressedTensorsWNA16" to "[WIP, Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16" Aug 30, 2024
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/machete-end2end branch 2 times, most recently from 953973d to 90f8bb6 Compare September 10, 2024 21:08
@LucasWilkinson LucasWilkinson changed the title from "[WIP, Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16" to "[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin" Sep 11, 2024
@LucasWilkinson LucasWilkinson marked this pull request as ready for review September 11, 2024 20:07
@LucasWilkinson
Contributor Author

\ready

@LucasWilkinson
Contributor Author

@mgoin @dsikka

@@ -328,6 +328,64 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
marlin_tile_size=self.marlin_tile_size)


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
Contributor

In the current lifecycle, is this only meant to be called after weight loading, i.e. in process_weights_after_loading?

Contributor Author

@LucasWilkinson LucasWilkinson Sep 12, 2024

Yes, it's meant to move the param into a more standard layout so that fewer repacking kernels are needed. Basically it does:

standard_layout = {packed_dim: 0, input_dim: 0, output_dim: 1, packed_order: "contiguous"}
CT -> standard_layout
GPTQ -> standard_layout

so that we only need:

standard_layout -> Machete (i.e. `machete_prepack_B`)
standard_layout -> Marlin (i.e. `gptq_marlin_repack`)

Eventually it would be nice for this to also support:

AWQ -> standard_layout
QQQ -> standard_layout

So this will likely be refactored in the near future, once we finalize a new design for the quantized linear code.
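The idea of routing every checkpoint format through one standard layout can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: all names are hypothetical, bit-packing is ignored so that the layout change reduces to a 2-D transpose, and the source layouts shown for CT and GPTQ are assumed for the example.

```python
# Illustrative sketch only: every name here is hypothetical, and bit-packing
# is ignored so the layout change reduces to a 2-D transpose.

def transpose(matrix):
    """Swap the two dimensions of a nested-list matrix."""
    return [list(col) for col in zip(*matrix)]


def to_standard_layout(weight, input_dim, output_dim):
    """Return `weight` with the input dim at 0 and the output dim at 1."""
    if (input_dim, output_dim) == (0, 1):
        return weight
    if (input_dim, output_dim) == (1, 0):
        return transpose(weight)
    raise ValueError(f"unsupported layout: {(input_dim, output_dim)}")


# Assumed source layouts, as (input_dim, output_dim), for illustration only.
SOURCE_LAYOUTS = {"CT": (1, 0), "GPTQ": (0, 1)}

# Stand-ins for the real repack kernels; they only tag the weight here.
KERNEL_REPACKERS = {
    "Machete": lambda w: ("machete", w),
    "Marlin": lambda w: ("marlin", w),
}


def prepare(weight, source, kernel):
    """Convert source -> standard layout -> kernel-specific layout."""
    standard = to_standard_layout(weight, *SOURCE_LAYOUTS[source])
    return KERNEL_REPACKERS[kernel](standard)


ct_weight = [[1, 2, 3], [4, 5, 6]]       # (output=2, input=3)
gptq_weight = [[1, 4], [2, 5], [3, 6]]   # (input=3, output=2)

# Both sources reach the same kernel input via the standard layout.
assert prepare(ct_weight, "CT", "Machete") == prepare(gptq_weight, "GPTQ", "Machete")
```

With S source formats and K kernels, this design needs S + K conversions instead of S × K direct ones. For the two sources and two kernels above the count is 4 either way, but adding AWQ and QQQ would give 6 conversions instead of 8.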

Contributor

@dsikka dsikka Sep 19, 2024

Can we have this function return a torch.nn.Parameter? After weight loading, vLLMParameters are no longer needed, and we also need the data to be in a torch.nn.Parameter for torch.compile.

e.g:

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

Contributor Author

@LucasWilkinson LucasWilkinson Sep 19, 2024

After our conversation on Thursday I made this change as part of _transform_param, since I'm not really a fan of giving that responsibility to a layout utility function. (I could rename this function, but I still like the idea of having a utility that purely manipulates the layout. I am open to having my mind changed.)

That _transform_param change is here: c452a86

I tested torch.compile after feedback from @bnellnm and it ran fine, so I think we are good on that front!

Contributor

awesome, thank you!

move heuristic into C++ code

fix unit tests + format

update for 3.5.1

remove custom scheduler

codespell

cleanup comment

cleanup diff

review comments

review comments

review comment changes

review comments

fix codespell

cleanup util logic

make dim names for prepack layout more canoncial

missed refactor

wip

interleaving + recasting

tweak tolerances

comments plus interleaving

format

codespell

review comments

end2end first pass

seperate out kernels, format

add machete as a gptq backend

update to use  ModelWeightParameter

formatting

update parameter.py

refactor permute layout

wip

Sponsor Collaborator

@mgoin mgoin left a comment

This is in a super solid place, thanks for addressing the reviews well!

@mgoin mgoin added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Sep 19, 2024
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

LGTM - just a couple of minor questions/comments

benchmarks/kernels/benchmark_machete.py (resolved)
vllm/_custom_ops.py (resolved)
@mgoin mgoin merged commit 86e9c8d into vllm-project:main Sep 23, 2024
74 checks passed
agt pushed a commit to agt/vllm that referenced this pull request Sep 24, 2024
…TQMarlin (vllm-project#7701)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Manikandan-Thangaraj-ZS0321 added a commit to Manikandan-Thangaraj-ZS0321/vllm that referenced this pull request Sep 25, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 27, 2024
…TQMarlin (vllm-project#7701)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
7 participants