Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Intel] Add CPU inference backend #3634

Merged
merged 32 commits into from
Apr 2, 2024

Conversation

bigPYJ1151
Copy link
Contributor

@bigPYJ1151 bigPYJ1151 commented Mar 26, 2024

This PR adds a new CPU backend to vLLM and supports the basic model inference feature, with BF16 and FP32 dtype. FP16 support and TP support will be added in the future.

Changes to vLLM:

  • Added VLLM_TARGET_DEVICE ENV to specify backend explicitily.
  • Added CPUExecutor to isolate CPU backend with others.
  • Added TorchSDPABackend to support MHA on CPU.
  • Added _C related kernels on CPU.
  • Forwarded DeviceConfig to CacheEngine to avoid cuda hardcoded device memory allocation.
  • Added documents with install instructions.

RFC: #3654


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@bigPYJ1151 bigPYJ1151 force-pushed the cpu_pr branch 3 times, most recently from de3232c to 02ec530 Compare March 26, 2024 15:58
@WoosukKwon WoosukKwon self-assigned this Mar 26, 2024
@WoosukKwon
Copy link
Collaborator

Hi @bigPYJ1151 Thanks for updating the PR! It looks really nice.

Just for other people's understanding, could you write an RFC about the overall design, supported features, key technical decisions, and integration plan? I think this should be easy since you already wrote most of them in the previous PR. Please check out #3620 and #1866 for reference.

Copy link
Contributor

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a couple of CMake questions/nits

cmake/cpu_extension.cmake Outdated Show resolved Hide resolved
cmake/cpu_extension.cmake Show resolved Hide resolved
@bigPYJ1151
Copy link
Contributor Author

@WoosukKwon Sure, please refer to #3654

@hustnn
Copy link

hustnn commented Mar 27, 2024

Is there initial performance result for cpu reference?

@zhouyuan
Copy link
Contributor

Hi @hustnn ,

In general the performance number on CPU is not as good as GPU, for both latency and throughput. However we do find there are two value proposition for vLLM w/ CPU based on our initial tests:

  • much higher throughput vs naïve/static batching solutions(TGI), this is due to the great throughput oriented design of vLLM
  • for near-offline inference cases, vLLM w/ CPU throughput performance is competitive vs. entry-level GPU, mostly due to the much larger KV cache space(vs. GPU). The cost for CPU based solution may also be lower.

thanks,
-yuan

Copy link
Sponsor Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall nice job keeping the complexity in check and working with recent changes. Left a few comments

Dockerfile.cpu Outdated Show resolved Hide resolved
csrc/cpu/pybind.cpp Outdated Show resolved Hide resolved

.. code-block:: console

$ VLLM_TARGET_DEVICE=cpu python setup.py install
Copy link
Sponsor Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include an example of how to enable bf16 support i.e. through enabling the VLLM_CPU_AVX512BF16 env var

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BF16 is always supported in the cpu backend. AVX512BF16 is an extension ISA. The build script will check the host CPU flag to determine whether to enable it. VLLM_CPU_AVX512BF16 is used for the cross-compilation requirement. To clarify this question, I have updated some information in the doc.

@hustnn
Copy link

hustnn commented Mar 28, 2024

In general the performance number on CPU is not as good as GPU, for both latency and throughput. However we do find there are two value proposition for vLLM w/ CPU based on our initial tests:

  • much higher throughput vs naïve/static batching solutions(TGI), this is due to the great throughput oriented design of vLLM
  • for near-offline inference cases, vLLM w/ CPU throughput performance is competitive vs. entry-level GPU, mostly due to the much larger KV cache space(vs. GPU). The cost for CPU based solution may also be lower.

@zhouyuan Thanks for your reply, it is very helpful. These 2 points match our requirement quite well. We are planning to integrate a inference operator into a olap database, we care more about throughput compared to latency since we are targeting some offline analysis scenario.

Do you have any suggestion on how should we start with some experiment? Should we wait for these MR to be merged? We want to get some initial number on the throughput and see it is acceptable or any improvement we can further do from DB's aspect.

I also found a article from intel, is it the result and method consistent with your testing?
https://medium.com/@NeuralCompressor/llm-performance-of-intel-extension-for-transformers-f7d061556176

@zhouyuan
Copy link
Contributor

In general the performance number on CPU is not as good as GPU, for both latency and throughput. However we do find there are two value proposition for vLLM w/ CPU based on our initial tests:

  • much higher throughput vs naïve/static batching solutions(TGI), this is due to the great throughput oriented design of vLLM
  • for near-offline inference cases, vLLM w/ CPU throughput performance is competitive vs. entry-level GPU, mostly due to the much larger KV cache space(vs. GPU). The cost for CPU based solution may also be lower.

@zhouyuan Thanks for your reply, it is very helpful. These 2 points match our requirement quite well. We are planning to integrate a inference operator into a olap database, we care more about throughput compared to latency since we are targeting some offline analysis scenario.

Do you have any suggestion on how should we start with some experiment? Should we wait for these MR to be merged? We want to get some initial number on the throughput and see it is acceptable or any improvement we can further do from DB's aspect.

Hi @hustnn
The dockerfile in this patch maybe a good start to check:
https://github.com/vllm-project/vllm/blob/384623538c081ed621b04c1eec107132920e5045/Dockerfile.cpu

If build successfully, the docker image should be enough to run some benchmarks via the scripts provided in vLLM:
https://github.com/vllm-project/vllm/tree/main/benchmarks

Please note you may need to set some params to do NUMA binding as this may impact the performance for vLLM w/ CPU

For real deployment, vLLM provides the several methods to expose the service endpoint:
https://docs.vllm.ai/en/latest/serving/deploying_with_docker.html
You may then connect your application to the vLLM endpoint via langchain or other soltuons.

I also found a article from intel, is it the result and method consistent with your testing? https://medium.com/@NeuralCompressor/llm-performance-of-intel-extension-for-transformers-f7d061556176

Yes, the performance is improved if using INT4 quantization from Intel extension for transformers.
Intel PyTorch extensions is also a good refence. Here's the link to the project:
https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance.html

thanks,
-yuan

@WoosukKwon
Copy link
Collaborator

@bigPYJ1151 @zhouyuan QQ: Can we use torch.compile to auto-generate the custom C++ kernels except PagedAttention? This would increase the maintainability of the code a lot. I'm wondering how torch.compile performs on Intel CPUs.

@bigPYJ1151
Copy link
Contributor Author

@WoosukKwon Agree, I think this might be a good direction to try. For these element-wise operations and normalization operations, using torch.compile would unify the front-end to Python code and use different device backends to apply optimizations and generate binary code.

TorchInductor has two IR lowering path:

  • PyTorch → TorchDynamo → TorchInductor → Triton → NVIDIA GPU
  • PyTorch → TorchDynamo → TorchInductor → OpenMP (C++) → CPU

The second path is designed for CPU, and is under active development and evolution. Here is a blog contains some examples for your reference.

We need to further check the current development status and any gaps to utilize it to vLLM.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bigPYJ1151 Thanks for the great PR! And apologies for the delay in the review.

I'm happy to see that the code is super clean and fits well into our new abstractions. I also appreciate that you provided detailed documents and instructions. My experience of building and running the CPU backend was indeed super smooth.

The only major design decision I'd like to discuss on this PR is whether we want to share ModelRunner between the GPU and CPU backends. While I do understand that sharing will save code duplication substantially, I'm a bit worried that any change in ModelRunner can break CPU backend. Also, sharing can over-complicate the development of the CPU backend since the CPU backend doesn't use CUDA graph, prefix caching, etc. That being said, for faster initial integration, we can merge the current PR and change this in the future PR. WDYT?

Dockerfile.cpu Outdated Show resolved Hide resolved
vllm/attention/backends/torch_sdpa.py Outdated Show resolved Hide resolved
vllm/attention/backends/torch_sdpa.py Show resolved Hide resolved
docs/source/getting_started/cpu-installation.rst Outdated Show resolved Hide resolved
docs/source/getting_started/cpu-installation.rst Outdated Show resolved Hide resolved
vllm/executor/cpu_executor.py Outdated Show resolved Hide resolved
vllm/worker/cache_engine.py Outdated Show resolved Hide resolved
vllm/worker/cpu_worker.py Outdated Show resolved Hide resolved
vllm/worker/cpu_worker.py Outdated Show resolved Hide resolved
logger = init_logger(__name__)


class CPUModelRunner(ModelRunner):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I understand the context of the previous PR that changed the hardcoded "cuda" to self.device. I actually thought the CPU backend would have its own model runner. And I still believe it can be a better design since it isolates the two backends from each other.

What do you think? As CUDA graph is not used for CPUs, I believe the amount of the code duplicated when defining a new model runner for CPUs would not be that large.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, separating CPUModelRunner entirely will avoid all potential code breaks. It needs to move hundreds of lines codes, we will do it in the future PRs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please create an issue about this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, refer to #3776

@bigPYJ1151 bigPYJ1151 force-pushed the cpu_pr branch 2 times, most recently from df6a6c1 to 153e239 Compare April 1, 2024 06:14
@bigPYJ1151
Copy link
Contributor Author

@WoosukKwon Thanks for your comments! I have fixed most of them.
For CPUModelRunner, yes, you are right, isolate it with ModelRunner will avoid potential code breaks completely. We can do it in the future to reduce the PR size.

vllm/worker/worker.py Outdated Show resolved Hide resolved
vllm/worker/cpu_worker.py Outdated Show resolved Hide resolved
vllm/executor/cpu_executor.py Outdated Show resolved Hide resolved
@@ -363,6 +364,10 @@ def add_cli_args(
default=False,
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--cpu-kvcache-space',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Can we somehow make this an argument only for CPU backends?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, this flag is too specific.

It is difficult to find a proper method using CLI arguments. I tried to change this parameter as an environment variable and highlight it using warning message when it is not set. Please look at the latest commit.

vllm/executor/cpu_executor.py Outdated Show resolved Hide resolved

- AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.

- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question: Can we always set this env var? When should the user set this flag off?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need. This is because the AVX512BF16 is not widely supported on older CPU models.

If users want to run vLLM on the host, the build script will check the host CPU flag to determine the compile flag.

If users want to build vLLM as a package on the host without AVX512BF16, and run vLLM on other machines with AVX512BF16, this env var should be enabled for this cross-compilation.

vllm/attention/backends/torch_sdpa.py Show resolved Hide resolved
logger = init_logger(__name__)


class CPUModelRunner(ModelRunner):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please create an issue about this?

vllm/worker/cpu_worker.py Outdated Show resolved Hide resolved
@bigPYJ1151
Copy link
Contributor Author

bigPYJ1151 commented Apr 1, 2024

Hi @WoosukKwon Thanks for your further comments. I have fixed them all, please check, thanks.

bigPYJ1151 and others added 3 commits April 1, 2024 11:51
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yuan Zhou <yuan.zhou@intel.com>
@hustnn
Copy link

hustnn commented Apr 2, 2024

hi @zhouyuan @bigPYJ1151 , Thanks for your great work! Do you have plan to public some results on cpu inference with vllm?

@zhouyuan
Copy link
Contributor

zhouyuan commented Apr 2, 2024

@bigPYJ1151 @zhouyuan @jikunshang LGTM! Huge thanks for the great work! Very excited to finally have the CPU backend.

Please continue to work on investigating torch.compile and separating out ModelRunner for CPUs. These will increase the maintainability of the backend. Thanks again for the great work!

@WoosukKwon thanks for the detailed review and much appreciated on your guidance! Sure, will follow up on the refactoring, features and performance optimizations.

thanks,
-yuan

@zhouyuan
Copy link
Contributor

zhouyuan commented Apr 2, 2024

hi @zhouyuan @bigPYJ1151 , Thanks for your great work! Do you have plan to public some results on cpu inference with vllm?

Hi @hustnn,

We are now reviewing/seeking approval for the performance data, should be able to publish the perf data soon if everything goes well.

thanks,
-yuan

@hustnn
Copy link

hustnn commented Apr 2, 2024

hi @zhouyuan @bigPYJ1151 , Thanks for your great work! Do you have plan to public some results on cpu inference with vllm?

Hi @hustnn,

We are now reviewing/seeking approval for the performance data, should be able to publish the perf data soon if everything goes well.

thanks, -yuan

@zhouyuan Thanks and look forward to the performance report.

njhill added a commit to njhill/vllm that referenced this pull request Apr 2, 2024
Using importlib version() function doesn't work in the docker image.

Introduced in vllm-project#3634

Equivalent fix to vllm-project#3735
@abhilash1910
Copy link
Contributor

abhilash1910 commented Apr 3, 2024

Great work @bigPYJ1151 @zhouyuan @jikunshang !
Looking forward to #3814 .
Thanks @WoosukKwon for awesome collaboration!.

@markluofd
Copy link

Will intel cpu backend be equipped with AsyncLLMEngine? api_server is currently using AsyncLLMEngine

@bigPYJ1151
Copy link
Contributor Author

Hi @markluofd the online inference of the CPU backend is still under tunning, we will enable it when it is ready.

@markluofd
Copy link

@bigPYJ1151 To imitate the code of https://github.com/vllm-project/vllm/pull/3814/files#diff-d1c5ec4ddd588e3b7cac13bde85a98ac5b20686dc16b9da3b1c324467c3be2b5 url, I added CPUExecutorAsync in cpu_executor.py. AsyncLLMEngine and api_server can work normally. I want to know if this method will affect the inference performance of the CPU.

@bigPYJ1151
Copy link
Contributor Author

@markluofd Yes, the performance may have some regression. Because the CPU inference thread pool(OpenMP), HTTP service thread pool, and tokenizer threads will scramble CPU cores.
We plan to isolate the inference thread pool from others to avoid this problem.

@markluofd
Copy link

@bigPYJ1151 ok, thanks!There is another question , the introduction says that only bf16 and fp32 are supported. I found that if the dtype is fp16, the CPU backend can also execute normally. I want to know whether it uses the fp16 or fp32 kernel ?(My machine is a 3rd generation cpu with no avx512_bf16 instruction)

@bigPYJ1151
Copy link
Contributor Author

@markluofd FP16 will be cast to BF16 right now. BF16 is always supported even if there is no avx512_bf16 ISA.
Pure FP16 support will be added soon, might be at the end of the month.

@markluofd
Copy link

@bigPYJ1151 I get it, thank you! I need to find a 4th generation CPU and test the performance of bf16. I also found that the overall utilization rate of the current 96-core CPU is about 16% (30 concurrent requests). I hope that subsequent features will be incorporated to bring higher CPU utilization.

@ProExpertProg
Copy link
Contributor

@bigPYJ1151 @WoosukKwon I did not want to slow down the merging of this PR but I was wondering if there's a plan to decouple the CPU backend via an additional level of abstraction to allow for choosing the backend at runtime (or startup time) as opposed to the build time. I'm guessing all calls that go to the _c pytorch extension would have to go through the worker/executor but it doesn't seem like it would be too complex?

@zhouyuan
Copy link
Contributor

zhouyuan commented Apr 4, 2024

@bigPYJ1151 I get it, thank you! I need to find a 4th generation CPU and test the performance of bf16. I also found that the overall utilization rate of the current 96-core CPU is about 16% (30 concurrent requests). I hope that subsequent features will be incorporated to bring higher CPU utilization.

Hi @markluofd
Thanks for reporting, besides the threading pool conflict issue posted by @bigPYJ1151, in CPU based env you may need to do several tunings to get better performance, especially on NUMA node access and OpenMP threads.

OMP_NUM_THREADS=32 numactl --physcpubind=0-31 --membind=0 python benchmark.py

In my 4th gen Xeon env, this tuning can bring ~30% perf improvement. I suppose this can also help to improve the resource utilization in your tests.

Here are some tuning recipe for CPU based env:
https://pytorch.org/tutorials/intermediate/torchserve_with_ipex.html
https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html

thanks,
-yuan

@bigPYJ1151
Copy link
Contributor Author

Hi @ProExpertProg It is feasible to load different backends dylib at runtime. vLLM has multple backends with different dependencies and configurations, so it might be a lot of works to support the runtime binding.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants