
[Kernel][Core] Add AWQ support to the Marlin kernel #6612

Merged

Conversation

Collaborator

@alexm-neuralmagic commented Jul 21, 2024

This PR adds end-to-end support for AWQ quantization inside the Marlin kernel.

  1. Modifies the existing gptq_marlin.cu to support zero-points
  2. Adds a new AWQ on-the-fly repack kernel that converts from AWQ format to Marlin format
  3. Adds a new linear layer: awq_marlin.py (a brief usage sketch follows below)
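As a rough, editorial usage sketch (not part of the PR itself): assuming the new method is selectable through vLLM's quantization argument, matching the --quantization awq_marlin flag discussed later in this thread, running an AWQ checkpoint through the Marlin path would look roughly like this:

from vllm import LLM, SamplingParams

# Hedged sketch: the model name is the AWQ checkpoint evaluated later in this thread;
# quantization="awq_marlin" forces the new Marlin-based AWQ path (it also becomes the
# default for AWQ checkpoints once this PR lands).
llm = LLM(model="casperhansen/llama-3-8b-instruct-awq", quantization="awq_marlin")
params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(["Hello, my name is"], params)[0].outputs[0].text)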

Here are initial performance results of awq_marlin (this PR) vs. awq (on vLLM main) for a Llama3-70B AWQ model on 2xA100 GPUs and a Llama3-8B AWQ model on 1xA100 GPU.

Llama3-70B AWQ on 2xA100 GPUs with prompt = 512 and decode = 256
[performance comparison chart]

Llama3-8B AWQ on 1xA100 GPU with prompt = 1024 and decode = 512
[performance comparison chart]

TODOs (may be done after this PR lands):

  1. Add an end-to-end correctness test for AWQ models, like the one we have for GPTQ models, covering both float16 and bfloat16
  2. Anything else?


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run the full CI, as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@alexm-neuralmagic
Collaborator Author

/ready

github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 21, 2024
Comment on lines 260 to 261
CALL_IF(4)
CALL_IF(4)
CALL_IF(8)
CALL_IF(8)
Sponsor Collaborator

Why are these doubled?

Collaborator Author

Good catch, removed duplicates

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
Sponsor Collaborator

How is this different from GPTQ? It looks similar to me at a glance

Collaborator Author

The internal unpacking is different: AWQ packs over columns while GPTQ packs over rows, and AWQ also interleaves groups of 8 (for 4-bit) or groups of 4 (for 8-bit) to be compatible with the dequantization PTX assembly.
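To make the layout difference concrete, here is a small editorial NumPy sketch (not code from the PR); the [0, 2, 4, 6, 1, 3, 5, 7] interleave shown for 4-bit AWQ is an assumption based on the commonly used AWQ layout and should be checked against awq_marlin_repack.cu:

import numpy as np

def pack_gptq_4bit(w):
    # GPTQ: packs 8 consecutive rows (the input/K dimension) of 4-bit values
    # into one int32, in sequential nibble order.
    k, n = w.shape
    packed = np.zeros((k // 8, n), dtype=np.uint32)
    for i in range(8):
        packed |= (w[i::8].astype(np.uint32) & 0xF) << (4 * i)
    return packed  # shape (K // 8, N)

def pack_awq_4bit(w):
    # AWQ: packs 8 consecutive columns (the output/N dimension) into one int32,
    # with nibbles interleaved so groups of 8 line up with the dequantization PTX.
    k, n = w.shape
    order = [0, 2, 4, 6, 1, 3, 5, 7]  # assumed 4-bit interleave order
    packed = np.zeros((k, n // 8), dtype=np.uint32)
    for nibble, col in enumerate(order):
        packed |= (w[:, col::8].astype(np.uint32) & 0xF) << (4 * nibble)
    return packed  # shape (K, N // 8)

w = np.random.randint(0, 16, size=(16, 16), dtype=np.uint8)
print(pack_gptq_4bit(w).shape, pack_awq_4bit(w).shape)  # (2, 16) (16, 2)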

@alexm-neuralmagic
Collaborator Author

@mgoin @robertgshaw2-neuralmagic added bfloat16 support

@casper-hansen
Contributor

As I told @robertgshaw2-neuralmagic on Discord, this kind of speedup warrants a backport to AutoAWQ.

Does it make sense to natively pack the weights in AutoAWQ for the Marlin format and if so, do you have any reference code for this now that zero points are supported?

@linpan

linpan commented Jul 26, 2024

--quantization awq_marlin does not work.

@alexm-neuralmagic
Collaborator Author

@linpan This should work on the latest main.

@zhyncs
Contributor

zhyncs commented Aug 1, 2024

Hi @alexm-neuralmagic, nice work, Alex! The performance of Marlin AWQ amazed me; excellent work.

@zhyncs
Contributor

zhyncs commented Aug 1, 2024

Hi @alexm-neuralmagic, I ran a 5k-prompt ShareGPT benchmark on LMDeploy and SGLang, and their AWQ performance was surprisingly close, which is incredible!

============ Serving Benchmark Result ============
Backend:                                 lmdeploy
Traffic request rate:                    inf
Successful requests:                     5000
Benchmark duration (s):                  271.50
Total input tokens:                      1104092
Total generated tokens:                  1004593
Total generated tokens (retokenized):    1009505
Request throughput (req/s):              18.42
Input token throughput (tok/s):          4066.59
Output token throughput (tok/s):         3700.11
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   137295.92
Median E2E Latency (ms):                 137663.99
---------------Time to First Token----------------
Mean TTFT (ms):                          131020.95
Median TTFT (ms):                        131106.98
P99 TTFT (ms):                           258403.29
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          33.09
Median TPOT (ms):                        32.58
P99 TPOT (ms):                           77.93
---------------Inter-token Latency----------------
Mean ITL (ms):                           1597.30
Median ITL (ms):                         53.95
P99 ITL (ms):                            37236.28
==================================================

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     5000
Benchmark duration (s):                  276.45
Total input tokens:                      1091912
Total generated tokens:                  1010809
Total generated tokens (retokenized):    1006137
Request throughput (req/s):              18.09
Input token throughput (tok/s):          3949.72
Output token throughput (tok/s):         3656.35
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   149898.05
Median E2E Latency (ms):                 154363.75
---------------Time to First Token----------------
Mean TTFT (ms):                          91768.39
Median TTFT (ms):                        77725.08
P99 TTFT (ms):                           206159.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          431.19
Median TPOT (ms):                        333.59
P99 TPOT (ms):                           2713.51
---------------Inter-token Latency----------------
Mean ITL (ms):                           762.29
Median ITL (ms):                         227.06
P99 ITL (ms):                            1026.13
==================================================

@alexm-neuralmagic
Collaborator Author

@zhyncs thanks for doing these benchmarks! I would also expect AWQ to be inherently faster than GPTQ, because AWQ has no activation ordering, especially for multi-GPU runs.

@zhyncs
Contributor

zhyncs commented Aug 2, 2024

Hi @alexm-neuralmagic, thank you for your reply. I ran a GSM8K eval with lm_eval on Llama 3 8B Instruct and the AWQ model, and found a significant decrease in accuracy. Is this expected? My replication steps and results are as follows:

python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct
python -m vllm.entrypoints.openai.api_server --model casperhansen/llama-3-8b-instruct-awq

lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=meta-llama/Meta-Llama-3-8B-Instruct,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True
lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=casperhansen/llama-3-8b-instruct-awq,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7506|±  |0.0119|
|     |       |strict-match    |     8|exact_match|↑  |0.7498|±  |0.0119|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7043|±  |0.0126|
|     |       |strict-match    |     8|exact_match|↑  |0.7051|±  |0.0126|

@alexm-neuralmagic
Collaborator Author

Recently we also merged a PR that increases the precision of Marlin (it uses full-precision FP32 global reductions): #6795

I would compare awq vs. awq_marlin for an apples-to-apples comparison (rather than comparing directly to fp16, to avoid any quantization-related errors).
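For readers following along, a minimal sketch of that apples-to-apples check with vLLM's offline API (the server launch commands shown later in the thread are the serving-side equivalent); the checkpoint name is the one evaluated in this thread:

from vllm import LLM, SamplingParams

checkpoint = "casperhansen/llama-3-8b-instruct-awq"
params = SamplingParams(temperature=0.0, max_tokens=64)

# Run each configuration in its own process so only one engine holds the GPU:
# quantization="awq" forces the original kernels, while "awq_marlin" is the
# default chosen for AWQ checkpoints after this PR.
llm = LLM(model=checkpoint, quantization="awq")  # or quantization="awq_marlin"
print(llm.generate(["What is 12 * 17?"], params)[0].outputs[0].text)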

@zhyncs
Contributor

zhyncs commented Aug 2, 2024

Hi @alexm-neuralmagic, OK. I saw that the PR you mentioned was merged after the latest release, so I will try again after the new version is released. Thank you.

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

Hi @alexm-neuralmagic, I tested vLLM 0.5.4 and noticed the accuracy has worsened. Is this expected? Thanks.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.6998|±  |0.0126|
|     |       |strict-match    |     8|exact_match|↑  |0.6998|±  |0.0126|

@alexm-neuralmagic
Collaborator Author

@zhyncs thanks for checking. Could you please provide reproduction instructions?

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

@zhyncs thanks for checking. Could you please provide reproduction instructions?

ref #6612 (comment)

@robertgshaw2-neuralmagic
Sponsor Collaborator

robertgshaw2-neuralmagic commented Aug 6, 2024

@zhyncs - so the drop is from 0.7043 to 0.6998?

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

Yes, the previous version already had accuracy issues, which this version was supposed to fix, but it ended up being even worse.

@alexm-neuralmagic
Collaborator Author

@zhyncs did you try --quantization awq (to force the original awq kernel)?

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

This has nothing to do with that; let's just compare it directly to fp16. The current drop in accuracy is unacceptable; it can't be used for online business at all. By the way, LMDeploy is much better than this; you can test it yourself. Thanks.

@robertgshaw2-neuralmagic
Sponsor Collaborator

robertgshaw2-neuralmagic commented Aug 6, 2024

The AWQ model is 4-bit quantized, so you should not expect to see the same scores between fp16 and int4.

What score does LMDeploy achieve for the AWQ model on GSM8K?

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

expect to see the same scores

Your understanding is mistaken. From beginning to end, there was never any mention or expectation of identical scores, so I don't know where this strange misunderstanding comes from. My point is that the current AWQ Marlin implementation loses too much accuracy. It's not that it can't lose any accuracy, but you must at least ensure the accuracy after quantization remains usable. Right now it is in an unusable state.

@robertgshaw2-neuralmagic
Sponsor Collaborator

robertgshaw2-neuralmagic commented Aug 6, 2024

I get the same accuracy scores when using awq_marlin vs using awq:

Client launch command:

lm_eval --model local-completions --tasks gsm8k --num_fewshot 8 --model_args model=casperhansen/llama-3-8b-instruct-awq,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=True

awq:

launch command (forces using the old awq kernels):

python -m vllm.entrypoints.openai.api_server --model casperhansen/llama-3-8b-instruct-awq --quantization awq

scores:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7005|±  |0.0126|
|     |       |strict-match    |     8|exact_match|↑  |0.6998|±  |0.0126|

awq_marlin:

launch command (we use the awq_marlin kernels by default):

python -m vllm.entrypoints.openai.api_server --model casperhansen/llama-3-8b-instruct-awq

scores:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.6998|±  |0.0126|
|     |       |strict-match    |     8|exact_match|↑  |0.6990|±  |0.0126|

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

I fully understand your comparison method and your results. Perhaps we should focus on how to improve the current AWQ accuracy. What do you think? @robertgshaw2-neuralmagic @alexm-neuralmagic

@robertgshaw2-neuralmagic
Sponsor Collaborator

robertgshaw2-neuralmagic commented Aug 6, 2024

I fully understand your comparison method and your results. Perhaps we should focus on how to improve the current AWQ accuracy. What do you think? @robertgshaw2-neuralmagic @alexm-neuralmagic

Yes - the source of the accuracy drop here seems to be the quality of the model, not the correctness of the Marlin kernel. Neural Magic did not create this model, and I do not know anything about how it was created. Feel free to try to improve the accuracy as you see fit. Neural Magic provides quantized checkpoints on our Hugging Face profile that are compatible with vLLM, along with replication instructions for creating and evaluating the models.

The scope of this work is simply to run any AWQ model as fast as possible, with accuracy scores that match the baseline implementations within numerical precision error, which this PR accomplishes.

@mgoin
Sponsor Collaborator

mgoin commented Aug 6, 2024

@zhyncs I think you are confusing "AWQ, the quantization algorithm" with "AWQ, the inference kernel".
In vLLM we only have inference kernels (AWQ and AWQ Marlin are our current choices). These require models to be pre-quantized using the AWQ quantization algorithm; AutoAWQ is the library most notably used for that.
Any inference engine that supports running AWQ models should give the same ~70% GSM8K accuracy reported in this thread, because this is a product of the quantized checkpoint (i.e. casperhansen/llama-3-8b-instruct-awq) rather than of any inference code changing the model weights. This is why we say it isn't an issue with vLLM, but with how the model was created (i.e., AutoAWQ with a specific config and calibration set). As Rob said, you claim LMDeploy gives better AWQ results, so please share what it gives when running casperhansen/llama-3-8b-instruct-awq.

If you want to reduce the impact of "AWQ, the quantization algorithm", you can produce your own quantized checkpoint using AutoAWQ with more conservative parameters. Looking at the quantization config for that checkpoint, you can see it uses "group_size": 128, which you could decrease to "group_size": 64 or even "group_size": 32 when quantizing, to improve accuracy preservation.
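For reference, a hedged sketch of re-quantizing with AutoAWQ at a smaller group size (standard AutoAWQ usage as I understand it; the output path and the group size of 64 are only illustrative choices):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
quant_path = "llama-3-8b-instruct-awq-g64"  # illustrative output directory
quant_config = {
    "zero_point": True,
    "q_group_size": 64,  # smaller than the checkpoint's 128, for better accuracy preservation
    "w_bit": 4,
    "version": "GEMM",
}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)  # uses AutoAWQ's default calibration set
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)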

@casper-hansen
Contributor

This drop in accuracy after quantization looks normal to me. The standard calibration dataset used was not math-related either.

@robertgshaw2-neuralmagic
Sponsor Collaborator

robertgshaw2-neuralmagic commented Aug 6, 2024

Also, running the model through huggingface, I get the same scores as we get in vllm:

  • launch
lm_eval --model hf --model_args pretrained=casperhansen/llama-3-8b-instruct-awq --tasks gsm8k --num_fewshot 8 --batch_size 16
  • scores:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.6967|±  |0.0127|
|     |       |strict-match    |     8|exact_match|↑  |0.6975|±  |0.0127|

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

I think you are confusing "AWQ, the quantization algorithm" versus "AWQ, the inference kernel".

@mgoin You misunderstood: I was saying that the accuracy of AWQ in LMDeploy is better, not that it ran the casperhansen/llama-3-8b-instruct-awq model. The quantization in LMDeploy is generated through lmdeploy lite auto_awq with a group size of 128, without using anything from AutoAWQ. The GSM8K eval results measured previously did not show as large an accuracy drop as the test results here. You can verify it yourselves.

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

Also, running the model through huggingface, I get the same scores as we get in vllm:

  • launch
lm_eval --model hf --model_args pretrained=casperhansen/llama-3-8b-instruct-awq --tasks gsm8k --num_fewshot 8 --batch_size 16
  • scores:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.6967|±  |0.0127|
|     |       |strict-match    |     8|exact_match|↑  |0.6975|±  |0.0127|

@robertgshaw2-neuralmagic Thank you for providing this reference.

@mgoin
Sponsor Collaborator

mgoin commented Aug 6, 2024

Okay @zhyncs, thank you for clarifying that you are talking about a separate checkpoint from the one you measured here in vLLM, and that LMDeploy quantizes models itself.
I think it is clear that we don't want to produce AWQ-quantized models in vLLM, because we would like to leave that to pre-deployment libraries like AutoAWQ. For your issue here, you should probably open an issue in AutoAWQ, since you would like to see improved accuracy from the checkpoints produced there.

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

Makes sense.

@zhyncs
Contributor

zhyncs commented Aug 6, 2024

Also, running the model through huggingface, I get the same scores as we get in vllm:

  • launch
lm_eval --model hf --model_args pretrained=casperhansen/llama-3-8b-instruct-awq --tasks gsm8k --num_fewshot 8 --batch_size 16
  • scores:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.6967|±  |0.0127|
|     |       |strict-match    |     8|exact_match|↑  |0.6975|±  |0.0127|

Hi @robertgshaw2-neuralmagic, after thinking about it, the eval results with hf for the AWQ checkpoint don't clarify anything, since the AutoAWQ kernel is used during lm_eval as well. In theory, therefore, there should be no difference from AWQ's original implementation in vLLM.

Labels
ready ONLY add when PR is ready to merge/full CI is needed

8 participants