[OpenVINO] Enable GPU support for OpenVINO vLLM backend (#8192)
sshlyapn authored Oct 2, 2024
1 parent afb050b commit f58d4fc
Showing 8 changed files with 446 additions and 107 deletions.
35 changes: 29 additions & 6 deletions docs/source/getting_started/openvino-installation.rst
@@ -3,7 +3,7 @@
Installation with OpenVINO
==========================

vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support. The OpenVINO vLLM backend supports the following advanced vLLM features:
vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support, as well as on both integrated and discrete Intel® GPUs (`the list of supported GPUs <https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu>`_). The OpenVINO vLLM backend supports the following advanced vLLM features:

- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)
@@ -53,34 +53,57 @@ Install from source
$ pip install --upgrade pip
$ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Finally, install vLLM with OpenVINO backend:

.. code-block:: console
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html <https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html>`_.
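Once the GPU drivers are set up, you can verify that OpenVINO detects the device. The snippet below is a minimal check using the standard ``openvino`` Python API; the exact device names reported depend on your system.

.. code-block:: python

    import openvino as ov

    core = ov.Core()
    # Expect something like ['CPU', 'GPU'] (or 'GPU.0', 'GPU.1' on multi-GPU systems).
    print(core.available_devices)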

.. _openvino_backend_performance_tips:

Performance tips
----------------

vLLM OpenVINO backend uses the following environment variables to control behavior:
vLLM OpenVINO backend environment variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- ``VLLM_OPENVINO_DEVICE`` to specify which device to use for inference. If there are multiple GPUs in the system, additional indices can be used to choose the proper one (e.g., ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, the CPU device is used by default; see the example after this list.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during the model loading stage. By default, compression is turned off. You can also export the model with different compression techniques using ``optimum-cli`` and pass the exported folder as ``<model_id>``.
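For illustration, here is a minimal offline-inference sketch that combines both variables. The model name, prompt, and device index are placeholders, and it assumes the environment variables are set before the engine is created so that the executor picks them up.

.. code-block:: python

    import os

    # Assumptions for this example: a second GPU exists (GPU.1) and
    # on-the-fly U8 weight compression is desired.
    os.environ["VLLM_OPENVINO_DEVICE"] = "GPU.1"
    os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)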

CPU performance tips
~~~~~~~~~~~~~~~~~~~~

The CPU device uses the following environment variables to control behavior:

- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user's memory management pattern.

- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used, depending on the platform.


To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).

OpenVINO best known configuration is:
OpenVINO best known configuration for CPU is:

.. code-block:: console
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
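If you prefer the offline API over the benchmark script, the same settings can be applied directly; the sketch below assumes that ``LLM`` accepts the usual ``enable_chunked_prefill`` and ``max_num_batched_tokens`` engine arguments and uses a placeholder model name.

.. code-block:: python

    import os

    os.environ["VLLM_OPENVINO_KVCACHE_SPACE"] = "100"          # GB reserved for the KV cache
    os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"  # compress the KV cache to U8
    os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-2-7b-chat-hf",
        enable_chunked_prefill=True,
        max_num_batched_tokens=256,  # recommended batch size from the experiments above
    )
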
GPU performance tips
~~~~~~~~~~~~~~~~~~~~
The GPU device implements logic for automatic detection of the available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking the ``gpu_memory_utilization`` option into account). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using the ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB of space for the KV cache).
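As a rough illustration of that sizing logic (a simplified sketch, not the actual implementation), the reserved KV cache size could be derived as follows:

.. code-block:: python

    import os

    GiB = 1 << 30

    def kv_cache_bytes(gpu_free_memory: int, gpu_memory_utilization: float) -> int:
        """Sketch: an explicit VLLM_OPENVINO_KVCACHE_SPACE setting wins,
        otherwise the size is derived from the available GPU memory."""
        explicit_gb = int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0"))
        if explicit_gb > 0:
            return explicit_gb * GiB
        return int(gpu_free_memory * gpu_memory_utilization)

    # e.g. 16 GB of free GPU memory and the default utilization of 0.9
    # leave roughly 14.4 GB for the KV cache.
    print(kv_cache_bytes(16 * GiB, 0.9) / GiB)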

Currently, the best GPU performance is achieved with the default vLLM execution parameters for models with quantized weights (8- and 4-bit integer data types are supported) and ``preemption-mode=swap``.
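Pre-quantized weights can be produced ahead of time with ``optimum-intel`` and the resulting folder passed to vLLM as ``<model_id>``. The following is a sketch assuming the ``optimum-intel`` Python API; the model name, output path, and ``load_in_8bit`` flag are illustrative.

.. code-block:: python

    from optimum.intel import OVModelForCausalLM

    # Export the model to OpenVINO IR and compress the weights to 8-bit
    # integers (4-bit is also possible via a quantization config).
    model = OVModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf", export=True, load_in_8bit=True)
    model.save_pretrained("./llama-2-7b-chat-ov-int8")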

OpenVINO best known configuration for GPU is:

.. code-block:: console
$ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json
.. _openvino_backend_limitations:

Limitations
5 changes: 3 additions & 2 deletions requirements-openvino.txt
@@ -3,5 +3,6 @@

# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0
optimum-intel[openvino] >= 1.18.2
openvino ~= 2024.4.0
openvino-tokenizers[transformers] ~= 2024.4.0
optimum-intel[openvino] >= 1.19.0
40 changes: 32 additions & 8 deletions vllm/attention/backends/openvino.py
@@ -9,6 +9,31 @@
from vllm.attention.backends.utils import CommonAttentionState


def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:

    def create_roi_tensor(
        tensor: ov.Tensor,
        block_number: int,
    ) -> ov.Tensor:
        roi_begin = ov.runtime.Coordinate([0, 0, 0, 0])
        roi_end = ov.runtime.Coordinate(tensor.get_shape())

        roi_begin[0] = block_number
        roi_end[0] = block_number + 1

        if isinstance(tensor, ov.Tensor):
            return ov.Tensor(tensor, roi_begin, roi_end)
        else:
            return ov.RemoteTensor(tensor, roi_begin, roi_end)

    src_roi_tensor = create_roi_tensor(src_tensor, src_offset)
    dst_roi_tensor = create_roi_tensor(dst_tensor, dst_offset)
    src_roi_tensor.copy_to(dst_roi_tensor)


class OpenVINOAttentionBackend(AttentionBackend):

@staticmethod
@@ -44,13 +69,12 @@ def get_kv_cache_shape(

@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
src_tensor: ov.Tensor,
dst_tensor: ov.Tensor,
src_to_dists: List[Tuple[int, int]],
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError
for src, dst in src_to_dists:
copy_cache_block(src_tensor, dst_tensor, src, dst)

@staticmethod
def copy_blocks(
@@ -59,8 +83,8 @@ def copy_blocks(
) -> None:
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
copy_cache_block(key_cache, key_cache, src, dst)
copy_cache_block(value_cache, value_cache, src, dst)


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -35,6 +35,7 @@
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
@@ -302,6 +303,11 @@ def get_default_config_root():
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),

# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),

# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
73 changes: 53 additions & 20 deletions vllm/executor/openvino_executor.py
@@ -17,15 +17,28 @@
logger = init_logger(__name__)


def is_openvino_cpu() -> bool:
return "CPU" in envs.VLLM_OPENVINO_DEVICE


def is_openvino_gpu() -> bool:
return "GPU" in envs.VLLM_OPENVINO_DEVICE


class OpenVINOExecutor(ExecutorBase):

uses_ray: bool = False

def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
assert is_openvino_cpu() or is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"

self.ov_core = ov.Core()
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.cache_config = _verify_and_get_cache_config(
self.ov_core, self.cache_config)

# Instantiate the worker and load the model to CPU.
self._init_worker()
@@ -40,6 +53,7 @@ def _init_worker(self):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
@@ -68,10 +82,13 @@ def initialize_cache(self, num_gpu_blocks: int,
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
# NOTE: In case of a CPU device, the `cpu block` for the OpenVINO backend
# is located in CPU memory but is referred to as a `gpu block`,
# because we want to reuse the existing block management procedure.
device_blocks = num_gpu_blocks
swap_blocks = num_cpu_blocks
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(
@@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return config


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
def _verify_and_get_cache_config(ov_core: ov.Core,
config: CacheConfig) -> CacheConfig:
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
logger.info("KV cache type is overried to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
if not is_openvino_cpu():
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is "
"ignored for GPU, f16 data type will be used.")
config.cache_dtype = ov.Type.f16
else:
logger.info("KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
else:
core = ov.Core()
inference_precision = core.get_property("CPU",
hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
if is_openvino_cpu():
ov_device = envs.VLLM_OPENVINO_DEVICE
inference_precision = ov_core.get_property(
ov_device, hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
else:
config.cache_dtype = ov.Type.f16
else:
config.cache_dtype = ov.Type.f16

if config.block_size != 32:
logger.info(
f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
if is_openvino_cpu():
if config.block_size != 32:
logger.info(
f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
else:
if config.block_size != 16:
logger.info(
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 16

kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0:
if kv_cache_space == 0 and is_openvino_cpu():
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
26 changes: 10 additions & 16 deletions vllm/model_executor/model_loader/openvino.py
@@ -12,6 +12,7 @@
import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig
from vllm.executor.openvino_executor import is_openvino_cpu
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
@@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
shape = parameter.get_partial_shape()
# use real block size if available, just a placeholder
# to provide the expected rank
x_size = 1
num_blocks = ov.Dimension()
block_size = ov.Dimension()
head_size = ov.Dimension()
# TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
# pass more parameters to this function to set more static dimensions
if input_name.startswith("key_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [
num_blocks,
shape[1],
shape[2].get_length() //
x_size if shape[2].is_static else ov.Dimension(),
block_size,
x_size,
]
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
elif input_name.startswith("value_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
else:
continue
parameter.set_partial_shape(
@@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module):

def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
@@ -141,12 +133,12 @@ def __init__(
trust_remote_code=model_config.trust_remote_code,
)

ov_device = envs.VLLM_OPENVINO_DEVICE
paged_attention_transformation(pt_model.model)
_modify_cache_parameters(pt_model.model, kv_cache_dtype,
device_config.device.type == "cpu")
is_openvino_cpu())

core = ov.Core()
ov_compiled = core.compile_model(pt_model.model, "CPU")
ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
self.ov_request = ov_compiled.create_infer_request()

def forward(
@@ -199,11 +191,13 @@ def get_model(
**kwargs,
) -> torch.nn.Module:
lora_config = kwargs.get("lora_config", None)
ov_core = kwargs.get("ov_core")
if lora_config:
raise ValueError(
"OpenVINO modeling does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")

return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)
return OpenVINOCasualLM(ov_core, model_config, device_config,
kv_cache_dtype)
11 changes: 6 additions & 5 deletions vllm/worker/openvino_model_runner.py
@@ -42,6 +42,7 @@ class OpenVINOModelRunner:

def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
@@ -55,6 +56,7 @@ def __init__(
*args,
**kwargs,
):
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@@ -89,11 +91,10 @@ def __init__(
self.model: nn.Module # Set after init_Model

def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
device_config=self.device_config,
kv_cache_dtype=self.kv_cache_dtype,
)
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)

def _prepare_model_input(
self,