Fix tracking of model forward time in PP>1
sfc-gh-mkeralapura committed Aug 12, 2024
1 parent 6aa33cb commit 91f253a
Showing 2 changed files with 33 additions and 6 deletions.
5 changes: 0 additions & 5 deletions vllm/engine/arg_utils.py
@@ -873,11 +873,6 @@ def create_engine_config(self, ) -> EngineConfig:
                 raise ValueError(
                     f"Invalid module {m} in collect_detailed_traces. "
                     f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
-            if (m == "model"
-                    or m == "all") and self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Collection of detailed traces for the 'model' module is "
-                    "not yet supported with pipeline parallelism.")
         observability_config = ObservabilityConfig(
             otlp_traces_endpoint=self.otlp_traces_endpoint,
             collect_model_forward_time="model" in detailed_trace_modules
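With the guard gone, detailed model traces can now be requested together with pipeline parallelism. A minimal sketch of the corresponding engine configuration, under stated assumptions: the model name and OTLP endpoint are placeholders, and building the config requires the model's Hugging Face config to be resolvable.

from vllm import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",                     # placeholder model
    pipeline_parallel_size=2,                      # PP > 1, previously rejected here
    otlp_traces_endpoint="http://localhost:4317",  # placeholder OTLP collector
    collect_detailed_traces="model",               # now allowed with PP > 1
)
# create_engine_config() is the method patched above; it now builds an
# ObservabilityConfig with collect_model_forward_time=True instead of raising.
engine_config = engine_args.create_engine_config()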
34 changes: 33 additions & 1 deletion vllm/worker/model_runner.py
@@ -1551,6 +1551,21 @@ def execute_model(

         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
+            if (self.is_driver_worker
+                    and hidden_or_intermediate_states is not None
+                    and isinstance(hidden_or_intermediate_states,
+                                   IntermediateTensors)
+                    and self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                hidden_or_intermediate_states.tensors["model_forward_time"] = (
+                    torch.tensor(model_forward_time + orig_model_forward_time))
             return hidden_or_intermediate_states
 
         logits = self.model.compute_logits(hidden_or_intermediate_states,
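The added block is the core of the fix: each pipeline rank other than the last measures its own forward pass with CUDA events, adds the running total it received from the previous rank via `intermediate_tensors`, and ships the new sum downstream inside `hidden_or_intermediate_states`, so the final rank can report an end-to-end forward time. A self-contained sketch of that pattern, assuming a CUDA device is available; `run_stage` and its arguments are illustrative stand-ins, not vLLM API:

import torch

def run_stage(stage_fn, tensors: dict) -> dict:
    # Time this rank's share of the forward pass with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = stage_fn(tensors)
    end.record()
    end.synchronize()  # elapsed_time() requires the end event to have completed
    # Fold the running total carried by the previous rank into this rank's time.
    prev_ms = tensors.get("model_forward_time", torch.tensor(0.0)).item()
    out["model_forward_time"] = torch.tensor(start.elapsed_time(end) + prev_ms)
    return out

if torch.cuda.is_available():
    layer = torch.nn.Linear(16, 16, device="cuda")
    stage = lambda t: {"hidden": layer(t["hidden"])}
    t0 = {"hidden": torch.randn(4, 16, device="cuda")}
    t1 = run_stage(stage, t0)    # "rank 0": local time only
    t2 = run_stage(stage, t1)    # "rank 1": receives and extends the total
    print(t2["model_forward_time"].item())  # ~ summed forward time in ms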
@@ -1576,6 +1591,23 @@ def execute_model(
             # the communication time as well.
             output.model_forward_time = model_forward_time
 
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time
+                and output is not None):
+            model_forward_end.synchronize()
+            model_forward_time = model_forward_start.elapsed_time(
+                model_forward_end)
+            orig_model_forward_time = 0.0
+            if intermediate_tensors is not None:
+                orig_model_forward_time = intermediate_tensors.tensors.get(
+                    "model_forward_time", torch.tensor(0.0)).item()
+            # If there are multiple workers, we are still tracking the latency
+            # from the start time of the driver worker to the end time of the
+            # driver worker. The model forward time will then end up covering
+            # the communication time as well.
+            output.model_forward_time = (orig_model_forward_time +
+                                         model_forward_time)
+
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
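On the last rank, the same two quantities meet: the total accumulated by the upstream ranks (read from `intermediate_tensors`) and the local measurement, and their sum is what lands in `output.model_forward_time`. The measurement itself is the standard CUDA event idiom, shown standalone below; `model_forward_start`/`model_forward_end` in the diff are recorded around the forward call elsewhere in `execute_model`, and this sketch assumes a CUDA device:

import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    x = torch.randn(1024, 1024, device="cuda")
    start.record()              # enqueued on the stream, not a CPU timestamp
    y = x @ x                   # stands in for the model forward pass
    end.record()
    end.synchronize()           # block until the end event has actually run
    print(f"forward took {start.elapsed_time(end):.3f} ms")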
@@ -1732,7 +1764,7 @@ def forward(
                                     **kwargs)
         if intermediate_tensors is not None:
             for key in intermediate_tensors.tensors:
-                if key != "model_execute_time":
+                if key != "model_execute_time" and key != "model_forward_time":
                     self.input_buffers[key].copy_(intermediate_tensors[key],
                                                   non_blocking=True)
         # Run the graph.
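This last hunk is needed because the CUDA graph runner replays a captured graph into fixed input buffers: only keys present at capture time have a buffer to copy into, and `model_forward_time`, like `model_execute_time` before it, is per-step bookkeeping attached to the intermediate tensors rather than a model input. A generic illustration of that filtering (toy names, not the vLLM class):

import torch

# Bookkeeping keys that travel with intermediate tensors but are not
# model inputs captured in the CUDA graph.
_METADATA_KEYS = {"model_execute_time", "model_forward_time"}

def copy_graph_inputs(input_buffers: dict, tensors: dict) -> None:
    # Copy per-step tensors into the graph's captured input buffers,
    # skipping metadata keys that have no corresponding buffer.
    for key, value in tensors.items():
        if key in _METADATA_KEYS:
            continue
        input_buffers[key].copy_(value, non_blocking=True)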
