
[Core] Logprobs support in Multi-step #7652

Merged: 49 commits merged on Aug 30, 2024
Changes from 10 commits

Commits (49)
f97e0ae
added example
afeldman-nm Aug 21, 2024
f969241
wip:
afeldman-nm Aug 21, 2024
642d31b
first working attempt at logprobs
afeldman-nm Aug 21, 2024
a0ca262
merge; format
afeldman-nm Aug 21, 2024
ed97288
passing test; dataclass
afeldman-nm Aug 21, 2024
861e1b9
refactoring
afeldman-nm Aug 21, 2024
8bc0765
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
a34d1ac
refactoring
afeldman-nm Aug 21, 2024
4cda5c0
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 21, 2024
ac8a39a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
1284327
removing example
afeldman-nm Aug 21, 2024
a6c1207
removed example from build pipeline
afeldman-nm Aug 21, 2024
fe42995
fixed one docstring; embedded NUM_LOGPROBS
afeldman-nm Aug 21, 2024
9fb5bbe
test refactor
afeldman-nm Aug 21, 2024
046a8b1
incremental refactors
afeldman-nm Aug 21, 2024
fa86efd
remove unnecessary conftest change
afeldman-nm Aug 21, 2024
1c0ffb6
Update vllm/model_executor/layers/sampler.py
afeldman-nm Aug 21, 2024
3babadb
refactor
afeldman-nm Aug 21, 2024
f502029
Merge branch 'afeldman-nm/logprobs' of https://github.com/neuralmagic…
afeldman-nm Aug 21, 2024
1875b37
test_multi_step comment
afeldman-nm Aug 21, 2024
3760a95
utils function docstrings
afeldman-nm Aug 21, 2024
d43308c
docstring refactors
afeldman-nm Aug 21, 2024
54db498
merge
afeldman-nm Aug 21, 2024
dfbbaf0
passing tests & formatted
afeldman-nm Aug 21, 2024
5eebfca
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
5e23d9a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 22, 2024
717efa3
merge; format
afeldman-nm Aug 22, 2024
e0d59ce
removed incorrect SamplerOutput imports
afeldman-nm Aug 22, 2024
102fd92
formatting
afeldman-nm Aug 22, 2024
948f4ef
Update tests/multi_step/test_correctness.py
afeldman-nm Aug 22, 2024
6e6711f
fixed comment
afeldman-nm Aug 22, 2024
f61163e
merge; format
afeldman-nm Aug 23, 2024
1cc93dd
rename
afeldman-nm Aug 23, 2024
4995204
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 23, 2024
da5826b
test modification
afeldman-nm Aug 26, 2024
d4fb430
merge; format
afeldman-nm Aug 26, 2024
b6752e0
merge
afeldman-nm Aug 27, 2024
1e42656
formatting
afeldman-nm Aug 27, 2024
cd0fdf9
disabled logprobs pythonization when logprobs are disabled
afeldman-nm Aug 27, 2024
3fecbc4
wip
afeldman-nm Aug 27, 2024
67bd035
skip logprobs processing entirely when logprobs are not enabled; form…
afeldman-nm Aug 27, 2024
419659d
multi-step output processing; formatting
afeldman-nm Aug 27, 2024
55eaab9
wip
afeldman-nm Aug 27, 2024
bae1fb9
small fixes
afeldman-nm Aug 27, 2024
fbb75b7
reverting to no prompt-logprobs support; merged in main
afeldman-nm Aug 28, 2024
63c5582
timeout increase
afeldman-nm Aug 28, 2024
8191571
refactoring
afeldman-nm Aug 28, 2024
9a708f8
Merge branch 'main' into logprobs_no_prompt
afeldman-nm Aug 28, 2024
e54606d
upstream merge
afeldman-nm Aug 29, 2024
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -156,6 +156,7 @@ steps:
- python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_inference_multi_step.py

- label: Models Test # 1hr10min
source_file_dependencies:
35 changes: 35 additions & 0 deletions examples/offline_inference_multi_step.py
@@ -0,0 +1,35 @@
'''
Example of setting up an LLM with multi-step scheduling enabled.
In practice, the async engine is a more sensible choice from a
performance perspective; however, this example is useful for
demonstrating and debugging the multi-step code path.
'''

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="JackFram/llama-160m",
swap_space=16,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
num_scheduler_steps=8,
use_v2_block_manager=True,
enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
169 changes: 159 additions & 10 deletions tests/multi_step/test_correctness.py
@@ -1,8 +1,9 @@
# Test the AsyncLLMEngine with multi-step-decoding

from typing import List
from typing import Dict, List, Optional

import pytest
from openai.types.completion import Completion

from ..utils import RemoteOpenAIServer

@@ -22,9 +23,29 @@
"16",
]

NUM_LOGPROBS = [None, 5] # `logprobs` argument to OpenAI completions API

async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):

async def completions_with_server_args(
prompts: List[str],
model_name: str,
server_cli_args: List[str],
num_logprobs: Optional[int],
) -> Completion:
'''
Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.

Arguments:

* prompts: test prompts
* model_name: model to spin up on the vLLM server
* server_cli_args: CLI args for starting the server
* num_logprobs: `logprobs` argument to the completions API; `None` -> no logprobs

Returns:

* OpenAI Completion instance
'''

outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
@@ -33,12 +54,103 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
max_tokens=5,
logprobs=num_logprobs)
assert outputs is not None

return outputs


def get_text_generations(completions: Completion):
'''Obtain the generated text of each completion choice'''
return [x.text for x in completions.choices]


'''
Logprobs values are extracted as List[Optional[List[Dict[str,float]]]], i.e.:
* For each :class:`SequenceGroup`...
* ...if the completions API was invoked with a non-`None` `logprobs` argument:
* ...for each token offset in a sequence...
* ...store a mapping from str(token) -> logprob
* ...else, if the completions API was invoked with `logprobs=None`:
* ...store None
'''
LogprobType = List[Optional[List[Dict[str, float]]]]
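# Purely illustrative sketch (hypothetical values, not produced by any
# engine run): for two prompts, two generated tokens each, and `logprobs=2`,
# a LogprobType value has the following shape.
_EXAMPLE_LOGPROBS: LogprobType = [
    [  # SequenceGroup 0
        {" Paris": -0.02, " the": -4.1},   # token offset 0
        {".": -0.3, ",": -1.7},            # token offset 1
    ],
    [  # SequenceGroup 1
        {" a": -0.5, " an": -1.2},
        {" model": -0.9, " language": -1.0},
    ],
]
# With `logprobs=None`, the same structure collapses to e.g. [None, None].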


def get_logprob_generations(completions: Completion) -> LogprobType:
'''Obtain the top-`logprobs` logprob dict for each token in each :class:`SequenceGroup` (`None` where logprobs were disabled)'''
return [(None if x.logprobs is None else x.logprobs.top_logprobs)
for x in completions.choices]


def assert_all_close_logprobs(
ref_logprobs: LogprobType,
test_logprobs: LogprobType,
atol: float = 1e-3,
rtol: float = 1e-3,
) -> None:
'''
Asserts that logprobs produced by the vLLM engine instance under test
are very close to a set of ground-truth reference values.

If the completions API was invoked with a non-`None` `logprobs` argument,
then each individual reference logprob must be close to the test logprob,
according to the formula:

assert abs(tok_top_test_logprob - tok_top_ref_logprob) <= (
    atol + rtol * abs(tok_top_ref_logprob))

Else, if the completions API was invoked with `logprobs=None`, then
both the reference & test log probs should be List[None].

Arguments:

* ref_logprobs: ground-truth logprobs
* test_logprobs: logprobs produced by vLLM engine under test
* atol: absolute mismatch tolerance when comparing single logprobs
* rtol: relative mismatch tolerance when comparing single logprobs
'''

assert len(ref_logprobs) == len(test_logprobs), (
"Reference & test logprob SequenceGroup counts must match.")

if ref_logprobs[0] is None:
# It is expected that if one :class:`SequenceGroup` has
# `None` logprobs, then all :class:`SequenceGroup`s
# in the reference list have `None` logprobs.
# Validate this.
assert all([x is None for x in ref_logprobs])

# Next, assert that this is also true for
# test logprobs.
assert all([x is None for x in test_logprobs])
return

for (group_ref_logprobs,
group_test_logprobs) in zip(ref_logprobs, test_logprobs):

assert group_ref_logprobs is not None
assert group_test_logprobs is not None
assert len(group_ref_logprobs) == len(group_test_logprobs), (
"Reference & test logprob seq lens must match.")

for (token_ref_logprobs,
token_test_logprobs) in zip(group_ref_logprobs,
group_test_logprobs):
assert token_ref_logprobs.keys() == token_test_logprobs.keys(), (
"Reference & test top tokens must match.")
for (tok_str_ref,
tok_top_ref_logprob) in token_ref_logprobs.items():
tok_top_test_logprob = token_test_logprobs[tok_str_ref]

# Validate logprobs are numerically very close
assert abs(tok_top_test_logprob - tok_top_ref_logprob) <= (
atol + rtol * abs(tok_top_ref_logprob))
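# A minimal usage sketch of the helper above (hypothetical values, not
# produced by any engine), covering both the enabled- and the
# disabled-logprobs paths:
def _example_assert_all_close_logprobs_usage() -> None:
    ref: LogprobType = [[{" Paris": -0.021, " the": -4.10}]]
    test: LogprobType = [[{" Paris": -0.0210004, " the": -4.10002}]]
    # Each pair is within atol=1e-3 / rtol=1e-3, so this passes.
    assert_all_close_logprobs(ref, test, atol=1e-3, rtol=1e-3)
    # With logprobs disabled on both sides, both structures are all-None.
    assert_all_close_logprobs([None], [None])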


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
@@ -47,10 +159,37 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", NUM_LOGPROBS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
num_logprobs: Optional[int]):
'''
Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.

Set up an engine with single-step scheduling as a ground-truth reference.

Send a completions API request to both engines with the same prompts.

Validate:
* Generated tokens match
* Generated logprobs are all very close

Arguments:

* example_prompts: test fixture providing example prompts
* model: model under test (same for single- and multi-step engines)
* tp_size: degree of tensor-parallelism
* pp_size: degree of pipeline-parallelism
* eager_mode: whether to enforce eager-mode execution (no CUDA graphs)
* num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
* num_prompts: number of example prompts under test
* num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
'''

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -73,13 +212,23 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
]

ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
prompts, model, server_args + distributed_args, num_logprobs)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)

def get_text_generations(completions):
return [x.text for x in completions.choices]
prompts, model, ms_server_args + distributed_args, num_logprobs)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
assert ref_generations == test_generations

# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling.
ref_logprobs = get_logprob_generations(ref_completions)
test_logprobs = get_logprob_generations(test_completions)
assert_all_close_logprobs(
ref_logprobs,
test_logprobs,
atol=1e-5,
rtol=1e-5,
)
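
For manual debugging outside pytest, a rough equivalent of the multi-step side of this test can be reproduced against a locally started server. The sketch below is an assumption for illustration (the server command, port, and model mirror the test's configuration but are not part of this diff); it uses the standard OpenAI Python client, as the test does via RemoteOpenAIServer.

# Start a multi-step vLLM OpenAI-compatible server first, e.g.:
#   python -m vllm.entrypoints.openai.api_server \
#       --model JackFram/llama-160m \
#       --num-scheduler-steps 8 --use-v2-block-manager --enforce-eager
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(
        model="JackFram/llama-160m",
        prompt=["Hello, my name is", "The capital of France is"],
        temperature=0,
        max_tokens=5,
        logprobs=5,
    )
    for choice in completion.choices:
        # top_logprobs: one {token str -> logprob} dict per generated token.
        print(choice.text, choice.logprobs.top_logprobs)


asyncio.run(main())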
3 changes: 2 additions & 1 deletion tests/spec_decode/test_multi_step_worker.py
@@ -5,8 +5,9 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.sequence import ExecuteModelRequest, Logprob
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
3 changes: 2 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -7,8 +7,9 @@
import pytest
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
4 changes: 2 additions & 2 deletions tests/spec_decode/utils.py
@@ -8,12 +8,12 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
5 changes: 3 additions & 2 deletions tests/test_sequence.py
@@ -2,9 +2,10 @@

import pytest

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)

from .core.utils import create_dummy_prompt

4 changes: 2 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -24,12 +24,12 @@
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
6 changes: 3 additions & 3 deletions vllm/engine/llm_engine.py
@@ -29,16 +29,16 @@
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
PoolerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
3 changes: 2 additions & 1 deletion vllm/engine/output_processor/util.py
@@ -2,7 +2,8 @@
from typing import Sequence as GenericSequence
from typing import Union

from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput


def create_output_by_sequence_group(
2 changes: 1 addition & 1 deletion vllm/engine/protocol.py
@@ -5,11 +5,11 @@
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer


3 changes: 2 additions & 1 deletion vllm/executor/cpu_executor.py
@@ -11,8 +11,9 @@
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
3 changes: 2 additions & 1 deletion vllm/executor/distributed_gpu_executor.py
@@ -6,7 +6,8 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest

logger = init_logger(__name__)

3 changes: 2 additions & 1 deletion vllm/executor/executor_base.py
@@ -6,8 +6,9 @@
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest


class ExecutorBase(ABC):