[Core] Use flashinfer sampling kernel when available (vllm-project#7137)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
peng1999 and mgoin committed Aug 19, 2024
1 parent ff7ec82 commit f710fb5
Showing 5 changed files with 130 additions and 28 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -192,7 +192,9 @@ steps:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- tests/samplers
command: pytest -v -s samplers
commands:
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

- label: LogitsProcessor Test # 5min
mirror_hardwares: [amd]
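
This step now runs the samplers suite twice: once on the default code path and once with VLLM_USE_FLASHINFER_SAMPLER=1 so the new FlashInfer path gets CI coverage. Below is a minimal sketch of flipping the same switch from Python for a local run; the variable name comes from vllm/envs.py later in this commit, and the sketch assumes the flag is resolved lazily on first attribute access, which the lambda table in that file suggests.

```python
# Sketch: enable the FlashInfer sampler before vLLM reads its env vars,
# mirroring the CI command `VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers`.
import os

os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"  # "0" (the default) disables it

import vllm.envs as envs  # noqa: E402  (import after setting the variable)

assert envs.VLLM_USE_FLASHINFER_SAMPLER is True
```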
2 changes: 1 addition & 1 deletion Dockerfile
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################


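
The CUDA image moves from the FlashInfer 0.1.3 wheel to 0.1.4. Since the sampler only picks up FlashInfer when the package can be found at import time (see the find_spec probe in vllm/model_executor/layers/sampler.py further down), a quick standalone check like the following can confirm the wheel actually landed in an image; it uses only the standard library and is not vLLM code.

```python
# Sketch: confirm the optional FlashInfer dependency is importable and which
# version is installed, mirroring the find_spec("flashinfer") probe used by
# the sampler below.
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec

if find_spec("flashinfer") is None:
    print("flashinfer not installed; vLLM keeps using its PyTorch sampling path")
else:
    try:
        print("flashinfer", version("flashinfer"))  # expected: 0.1.4 in this image
    except PackageNotFoundError:
        print("flashinfer importable, but no distribution metadata was found")
```

If the probe fails, the conditional import further down leaves flashinfer_top_k_top_p_sampling as None and the sampler keeps its existing behaviour.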
37 changes: 36 additions & 1 deletion tests/samplers/test_sampler.py
@@ -8,6 +8,7 @@
import torch
from transformers import GenerationConfig, GenerationMixin

import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def mock_sample(probs, *args, **kwargs):
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
    # top-k/top-p are applied here only when the flashinfer kernel is unavailable
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

assert sample_probs is not None
@@ -645,6 +649,37 @@ def mock_sample(probs, *args, **kwargs):
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
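
Note that both the patch above and the fallback test that follows target vllm.model_executor.layers.sampler.flashinfer_top_k_top_p_sampling rather than anything inside flashinfer itself: sampler.py binds the kernel (or None) to its own module-level alias at import time, and that alias is the name the sampler actually looks up. A throwaway illustration of the pattern; demo_sampler, kernel and sample are made-up names, not vLLM code.

```python
# Sketch: patch the alias a module actually looks up, not the original symbol.
import sys
import types
from unittest.mock import patch

mod = types.ModuleType("demo_sampler")
mod.kernel = lambda: "fast path"        # stands in for the imported FlashInfer alias
mod.sample = lambda: "fallback" if mod.kernel is None else mod.kernel()
sys.modules["demo_sampler"] = mod       # patch() resolves the target via sys.modules

with patch("demo_sampler.kernel", None):
    assert mod.sample() == "fallback"   # callers inside the module see the patched alias
assert mod.sample() == "fast path"      # restored after the with block
```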


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")

set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)

def failing_flashinfer_sampling(*_args, **_kwargs):
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

with patch(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

assert sampler_output == fallback_sampler_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):

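
test_flashinfer_fallback hinges on the kernel's return contract: flashinfer_top_k_top_p_sampling yields a (token_ids, success) pair, and failing_flashinfer_sampling fabricates an all-zero success tensor, which is exactly the condition the sampler's fallback branch checks (see _top_k_top_p_multinomial_with_flashinfer below). A tiny standalone rendering of that contract; fake_kernel and its dtypes are illustrative stand-ins.

```python
# Sketch of the (token_ids, success) contract the fallback test exercises.
import torch

def fake_kernel(batch_size: int, converge: bool):
    # converge=False is what failing_flashinfer_sampling above simulates
    ids = torch.arange(batch_size, dtype=torch.int32) if converge else None
    success = torch.full((batch_size,), int(converge), dtype=torch.int32)
    return ids, success

for converge in (True, False):
    ids, success = fake_kernel(4, converge)
    print("kernel result" if bool(success.all()) else "renormalization fallback")
```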
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -30,6 +30,7 @@
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -256,6 +257,10 @@ def get_default_config_root():
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

    # If set, vllm will use the FlashInfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),

# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
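
The flag is parsed with bool(int(...)), so only integer-valued strings are accepted: "0" (the default) leaves the FlashInfer sampler off, any non-zero integer turns it on, and a value such as "true" raises a ValueError the first time the flag is read. A small standalone check of that behaviour; parse_flag simply repeats the lambda above with os.getenv inlined.

```python
# Sketch: how VLLM_USE_FLASHINFER_SAMPLER values are interpreted.
from typing import Optional

def parse_flag(raw: Optional[str]) -> bool:
    # same expression as the lambda in vllm/envs.py above, with os.getenv inlined
    return bool(int(raw if raw is not None else "0"))

assert parse_flag(None) is False      # unset -> FlashInfer sampler disabled
assert parse_flag("0") is False
assert parse_flag("1") is True
try:
    parse_flag("true")                # not an integer-valued string
except ValueError:
    print("only integer values such as 0 or 1 are accepted")
```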
110 changes: 85 additions & 25 deletions vllm/model_executor/layers/sampler.py
@@ -1,5 +1,7 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from importlib.util import find_spec
from math import inf
from typing import Dict, List, Optional, Tuple

@@ -11,6 +13,7 @@
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton

import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
@@ -19,6 +22,16 @@
PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

@@ -123,7 +136,7 @@ def forward(
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
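
With the FlashInfer kernel importable, forward() stops masking logits here and leaves top-k/top-p truncation to the fused sampling kernel (the _top_k_top_p_multinomial_with_flashinfer helper added further down). For orientation, a generic top-k/top-p logit mask looks roughly like the sketch below; it is an illustrative re-derivation, not vLLM's _apply_top_k_top_p, and threshold and tie-breaking conventions may differ.

```python
# Generic illustration of the top-k / top-p masking that is skipped when the
# FlashInfer kernel performs truncation during sampling instead.
import torch

def naive_top_k_top_p(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # top-k: drop everything below the k-th largest logit in each row
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # top-p: keep the smallest prefix of sorted tokens whose mass reaches p
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    drop_sorted = probs.cumsum(dim=-1) - probs > p   # tokens after the mass-p prefix
    drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
    return logits.masked_fill(drop, float("-inf"))

masked = naive_top_k_top_p(torch.randn(2, 10), k=5, p=0.9)
```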

@@ -476,32 +489,65 @@ def _multinomial(
seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
probs = probs.repeat_interleave(num_samples, dim=0)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
q[sample_idx:sample_idx +
stride].exponential_(generator=seq_group.generator)
sample_idx += stride
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
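
The refactor above collapses _multinomial's manual expand().contiguous().view() expansion into a single repeat_interleave call. The two forms produce identical row-wise copies, which the quick standalone check below confirms (values and shapes only; no performance claim is made here).

```python
# Quick check: the old expand().contiguous().view() expansion and the new
# repeat_interleave() produce identical row-wise copies.
import torch

probs = torch.rand(3, 5)
num_samples = 4
old = probs[:, None, :].expand(probs.shape[0], num_samples,
                               probs.shape[1]).contiguous().view(-1, probs.shape[1])
new = probs.repeat_interleave(num_samples, dim=0)
assert torch.equal(old, new)   # shape (3 * 4, 5), each row repeated 4 times in a row
```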


def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.uniform_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].uniform_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
probs,
uniform_samples,
top_ks,
top_ps,
)
if not success.all():
warnings.warn("FlashInfer rejection sampling failed, fallback.",
stacklevel=1)
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0])
return batch_next_token_ids.view(-1, num_samples)


def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@@ -564,18 +610,28 @@ def _sample_with_torch(
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
}

multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)

if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_best_of_in_batch,
seq_groups=seq_groups_arg)

if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)

elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(


def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
Expand All @@ -713,6 +772,7 @@ def _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
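
Taken together, the feature is strictly opt-in: install a FlashInfer wheel (the Dockerfile above pins 0.1.4) and set VLLM_USE_FLASHINFER_SAMPLER=1, otherwise the existing PyTorch sampling path is untouched. Below is a hedged end-to-end sketch of turning it on for an offline run; the model name and sampling values are illustrative, and the calls use vLLM's public LLM / SamplingParams interface.

```python
# Sketch: opt in to the FlashInfer sampling kernel for an offline generation run.
import os

os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"   # must be set before vLLM reads it

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m")              # illustrative model choice
params = SamplingParams(temperature=1.0, top_p=0.9, top_k=50)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```

Seeded requests keep working on this path: the per-sequence-group generators seed the uniform draws the kernel consumes, as _top_k_top_p_multinomial_with_flashinfer above shows.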
