
Commit

renamed outputs to cached_outputs
SolitaryThinker committed Aug 19, 2024
1 parent d80374f commit 2a07b6c
Showing 3 changed files with 45 additions and 63 deletions.
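
For context only (not part of the commit): a minimal sketch of the renamed field as exercised by the updated test below. It assumes a vLLM checkout at this revision and that the remaining StatefulModelInput dataclass fields keep their defaults, so only the arguments shown need to be passed.

    # Illustrative sketch, not vLLM source. Mirrors the updated test below;
    # assumes the other StatefulModelInput fields default sensibly at this commit.
    from vllm.worker.multi_step_model_runner import StatefulModelInput

    model_input = StatefulModelInput(
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],  # formerly `outputs=[]`
    )
    assert isinstance(model_input, StatefulModelInput)
    assert model_input.cached_outputs == []  # renamed attribute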
tests/worker/test_model_input.py (16 changes: 6 additions & 10 deletions)
@@ -10,8 +10,7 @@
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import (
StatefulModelInput)
from vllm.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
@@ -186,24 +185,21 @@ def test_multi_step_model_runner_input():
is_multi_step=True,
num_queries=8,
num_seqs=5,
outputs=[],
cached_outputs=[],
)

assert isinstance(model_input,
StatefulModelInput)
assert isinstance(model_input, StatefulModelInput)

# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (StatefulModelInput.
from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))

receieved_frozen_input = received_model_input.frozen_model_input

# Check that received copy has correct values.
assert isinstance(received_model_input,
StatefulModelInput)
assert isinstance(received_model_input, StatefulModelInput)
assert receieved_frozen_input.input_tokens is not None
assert (receieved_frozen_input.input_tokens ==
frozen_model_input.input_tokens).all()
vllm/worker/multi_step_model_runner.py (64 changes: 27 additions & 37 deletions)
@@ -52,33 +52,28 @@ class ModelOutput:
sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False

def pythonize(
self,
input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output. Blocking."""
if not self.pythonized:
self._pythonize_sampler_output(input_metadata, copy_stream,
pinned_sampled_token_buffer, True)
self.pythonized = True

def maybe_pythonize(
self,
input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
def maybe_pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output if ready, else return None. Non-blocking."""
if not self.pythonized:
self.pythonized = self._pythonize_sampler_output(
input_metadata, copy_stream, pinned_sampled_token_buffer,
False)

def _pythonize_sampler_output(
self,
input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor, blocking: bool) -> bool:
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor,
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
@@ -102,7 +97,7 @@ class StatefulModelInput(BroadcastableModelInput):
frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None

# list of model outputs for each step, may not be all pythonized
outputs: List[ModelOutput] = field(default_factory=list)
cached_outputs: List[ModelOutput] = field(default_factory=list)

# used to pass sampled token ids from the last step to the current step for
# TP workers. Used to append to end of outputs and used by advance_step
@@ -170,7 +165,7 @@ def wait_previous_step(self):
def add_sampler_output(self,
sampler_output: SamplerOutput,
sampled_token_ids: Optional[torch.Tensor] = None):
self.outputs.append(
self.cached_outputs.append(
ModelOutput(sampler_output=sampler_output,
sampler_output_ready_event=None,
sampled_token_ids=sampled_token_ids,
@@ -181,8 +176,7 @@ def add_sampler_output(self,
# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
# metadata
# mypy: disable-error-code=type-var
class MultiStepModelRunner(
GPUModelRunnerBase[StatefulModelInput]):
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var

def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
@@ -198,13 +192,11 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]
) -> StatefulModelInput:
model_input = (StatefulModelInput.
from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
return model_input

def prepare_model_input(
@@ -277,7 +269,7 @@ def execute_model(
# far ahead if needed)
model_input.wait_previous_step()
model_input = self._advance_step(
model_input, model_input.outputs[-1].sampler_output)
model_input, model_input.cached_outputs[-1].sampler_output)

# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
@@ -300,7 +292,7 @@ def execute_model(
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
model_input.outputs.append(
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False))
# make sure we dont try to serialize any GPU tensors
@@ -309,7 +301,7 @@ def execute_model(
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
for model_output in model_input.outputs:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)

@@ -325,7 +317,7 @@ def execute_model(
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = []
for output in model_input.outputs:
for output in model_input.cached_outputs:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
@@ -353,10 +345,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode

def _advance_step(
self, model_input: StatefulModelInput,
out: SamplerOutput
) -> StatefulModelInput:
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
@@ -377,7 +367,7 @@ def _advance_step(
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.outputs[-1].sampled_token_ids,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
@@ -420,8 +410,8 @@ def vocab_size(self) -> int:


def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput, pinned_sampled_token_buffer: torch.Tensor,
model_input: StatefulModelInput, output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> SamplerOutput:
""" This function is only called when the output tensors are ready.
See ModelOutput
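
Aside (illustrative, not vLLM code): the diff above keeps the deferred-pythonization pattern and only renames the list it lives on. A self-contained toy sketch of that pattern, with the CUDA event replaced by a plain flag:

    # Toy sketch of the cached_outputs pattern from multi_step_model_runner.py:
    # each step appends an output, later steps pythonize finished outputs
    # without blocking, and the last step drains whatever is left.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class _Output:
        value: int
        ready: bool = False        # stands in for the sampler_output_ready_event query
        pythonized: bool = False

        def maybe_pythonize(self) -> None:
            # Non-blocking: convert only if the (mock) GPU work already finished.
            if self.ready and not self.pythonized:
                self.pythonized = True

        def pythonize(self) -> None:
            # Blocking in the real code (waits on the CUDA event); here just convert.
            self.pythonized = True

    @dataclass
    class _State:
        cached_outputs: List[_Output] = field(default_factory=list)  # renamed from `outputs`

    state = _State()
    for step in range(4):
        state.cached_outputs.append(_Output(value=step, ready=step < 3))
        for out in state.cached_outputs:   # CPU runs ahead where it can
            out.maybe_pythonize()

    for out in state.cached_outputs:       # last step: drain everything
        out.pythonize()
    assert all(o.pythonized for o in state.cached_outputs)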
vllm/worker/multi_step_worker.py (28 changes: 12 additions & 16 deletions)
@@ -4,8 +4,8 @@
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (
MultiStepModelRunner, StatefulModelInput)
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput)
from vllm.worker.worker import Worker, WorkerInput


@@ -99,15 +99,15 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
token that is cached in the execute_model_req.
"""
if get_pp_group().is_last_rank:
assert model_input.outputs[
assert model_input.cached_outputs[
-1].sampler_output.sampled_token_ids is None
assert model_input.outputs[-1].sampled_token_ids is not None
model_input.last_sampled_token_ids = model_input.outputs[
assert model_input.cached_outputs[-1].sampled_token_ids is not None
model_input.last_sampled_token_ids = model_input.cached_outputs[
-1].sampled_token_ids
# free sampled token ids from the previous step if it has been
# pythonized. Cannot free the last sampled token ids because
# we need it for GPU advance_step.
for output in model_input.outputs[:-1]:
for output in model_input.cached_outputs[:-1]:
if output.pythonized:
output.sampled_token_ids = None
else:
@@ -123,15 +123,14 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
# free sampled token ids from the previous step.
# TODO(will) we could reuse the sampled token ids tensor from
# the previous step instead.
for output in model_input.outputs[:-1]:
for output in model_input.cached_outputs[:-1]:
output.sampled_token_ids = None
assert model_input.outputs[-1].sampled_token_ids is not None
assert model_input.cached_outputs[-1].sampled_token_ids is not None

def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[StatefulModelInput,
WorkerInput]]:
) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
"""
Depending on the current state of the request and multi step worker,
this method may skip the normal _prepare_model_input and
@@ -151,8 +150,7 @@ def prepare_input(
virtual_engine = execute_model_req.virtual_engine
model_input, worker_input = self._get_driver_input_and_broadcast(
execute_model_req)
assert isinstance(model_input,
StatefulModelInput)
assert isinstance(model_input, StatefulModelInput)
if execute_model_req.is_first_multi_step:
# cache the worker input and model input for the next steps
self.multi_step_states[virtual_engine] = MultiStepState(
@@ -165,8 +163,7 @@ def prepare_input(
if broadcast_data is None:
return None
model_input, worker_input = broadcast_data
assert isinstance(model_input,
StatefulModelInput)
assert isinstance(model_input, StatefulModelInput)
virtual_engine = worker_input.virtual_engine
if model_input.is_first_multi_step:
pass
@@ -179,8 +176,7 @@ def prepare_input(
# the model input states and we only broadcast the delta need
# for the next step (sampled_token_ids from the previous step)

assert isinstance(
model_input, StatefulModelInput)
assert isinstance(model_input, StatefulModelInput)
# we need to update the last sampled token ids in the model
# input for the workers so that they can run inplace
# advance_step
