
[Core] Add span metrics for model_forward, scheduler and sampler time #7089

Merged: 23 commits into vllm-project:main on Aug 9, 2024

Conversation

@sfc-gh-mkeralapura (Contributor) commented on Aug 2, 2024:

We are doing some work to better understand vLLM bottlenecks. Toward that end, having additional metrics/spans showing where time is spent is useful.

After this change, the span attributes will look something like the following (the bolded attributes are the ones added by this PR):

Attributes:
-> gen_ai.response.model: Str(/data-fast/s3/ml-dev-sfc-or-dev-misc1-k8s/yak/hf_models/meta-llama/Meta-Llama-3-70B)
-> gen_ai.request.id: Str(fcdae8c235f84b618c24b656f3ad4024)
-> gen_ai.request.temperature: Double(0)
-> gen_ai.request.top_p: Double(1)
-> gen_ai.request.max_tokens: Int(256)
-> gen_ai.request.best_of: Int(1)
-> gen_ai.request.n: Int(1)
-> gen_ai.usage.num_sequences: Int(1)
-> gen_ai.usage.prompt_tokens: Int(2106)
-> gen_ai.usage.completion_tokens: Int(256)
-> gen_ai.latency.time_in_queue: Double(0.012117147445678711)
-> gen_ai.latency.time_to_first_token: Double(0.2345738410949707)
-> gen_ai.latency.e2e: Double(3.775825023651123)
-> gen_ai.latency.time_in_scheduler: Double(0.017550230026245117)
-> gen_ai.latency.time_in_model_forward: Double(3.151565277099609)
-> gen_ai.latency.time_in_model_execute: Double(3.6468167304992676)

I had to put the model_forward time behind a flag since I could not get it working without the synchronize. The synchronize does not seem to add a delay, though: our expectation is that there should be no operations pending by the time the sampler is finished. Perhaps when there is no token, or in some other corner case, there is still something pending and elapsed_time() crashes with a CUDA runtime error: device is busy.
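For readers unfamiliar with the pattern, here is a minimal sketch of the CUDA-event timing described above (names and the stand-in workload are illustrative; the actual change lives in the vLLM model runner):

    import torch

    collect_model_forward_time = True  # gated behind a CLI flag in this PR

    if collect_model_forward_time:
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        forward_start.record()

    torch.zeros(1024, device="cuda").sum()  # stand-in for the model forward pass

    if collect_model_forward_time:
        forward_end.record()
        # elapsed_time() needs both events to have completed, so synchronize()
        # blocks the CPU until the end event is reached on the GPU.
        forward_end.synchronize()
        print(f"model forward: {forward_start.elapsed_time(forward_end):.3f} ms")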


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well documented to ensure that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and make use of the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC, excluding kernel/data/config/test code), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, it will be assigned to a reviewer. Reviewers pick up PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

github-actions bot commented on Aug 2, 2024:

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run the full CI, as it is required for merging (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@sfc-gh-mkeralapura changed the title from "Add span metrics for model_forward, scheduler and sampler time" to "[Core] Add span metrics for model_forward, scheduler and sampler time" on Aug 2, 2024
@ronensc (Contributor) left a comment:


Great job! Overall, this PR looks good to me. I've left a few comments. Once these are addressed, I think this will be ready to merge. Please make sure the CI tests pass. Thanks for your hard work!

vllm/sequence.py (review thread resolved)
Comment on lines 1359 to 1361
model_forward_start = torch.cuda.Event(enable_timing=True)
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
Contributor:

Could you please clarify if torch.cuda.Event() and Event.record() have any overhead when collect_model_forward_time=False?

Contributor Author:

From the docs, it is not supposed to have any impact (elapsed_time() and its synchronization are what can have an impact), and I ran benchmark_throughput.py (its measurements seem to have a variance of 5-10%) and did not notice anything significant. I have also run some of our custom internal benchmarking and confirmed this is not of consequence.
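For what it is worth, a rough way to check that claim in isolation (a hypothetical micro-benchmark, assuming a CUDA device; not part of this PR):

    import time
    import torch

    assert torch.cuda.is_available()
    N = 10_000
    start = time.perf_counter()
    for _ in range(N):
        ev = torch.cuda.Event(enable_timing=True)
        ev.record()  # enqueue the event on the current stream
    torch.cuda.synchronize()
    print(f"avg Event() + record(): {(time.perf_counter() - start) / N * 1e6:.2f} us")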

vllm/worker/worker_base.py (outdated review thread, resolved)
vllm/engine/arg_utils.py (outdated review thread, resolved)
@rkooo567 (Collaborator) left a comment:

Made an initial pass!

I have several questions:

  • Does it work with PP now?
  • Is it all-or-nothing trace collection?
  • Should we also allow logging these instead of collecting traces (for simpler dev use cases)?

@@ -977,6 +977,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.time()

Contributor Author:

Done.

@@ -660,6 +661,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=str,
default=None,
help='Target URL to which OpenTelemetry traces will be sent.')
parser.add_argument(
'--collect-model-forward-time',
Collaborator:

What about something like

--trace-components="model,scheduler," (it can be "all")

Contributor Author:

The reason to put only this behind a flag is that it is possibly costly, since it uses a synchronize operation underneath (fwiw, in our benchmarks it is not a factor). The others, scheduler time etc., are not costly; reporting them in the traces gives more detail at no cost.

What do you think?

Collaborator:

I was thinking more of the case where we support logging this (it is useful for dev). But as long as it is just for tracing, I have no problem with it.
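For illustration, a comma-separated components flag along the lines suggested above could be parsed like this (the argument name and accepted values are assumptions, not the final vLLM interface):

    import argparse

    VALID_COMPONENTS = {"model", "scheduler", "worker"}

    def parse_trace_components(value: str) -> set:
        components = {c.strip() for c in value.split(",") if c.strip()}
        if "all" in components:
            return set(VALID_COMPONENTS)
        unknown = components - VALID_COMPONENTS
        if unknown:
            raise argparse.ArgumentTypeError(
                f"unknown trace components: {sorted(unknown)}")
        return components

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trace-components", type=parse_trace_components, default=set(),
        help="Comma-separated components to collect detailed traces for "
             "('all' enables every component).")

    args = parser.parse_args(["--trace-components", "model,scheduler"])
    print(args.trace_components)  # {'model', 'scheduler'}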

@@ -1354,6 +1356,9 @@ def execute_model(
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
model_forward_start = torch.cuda.Event(enable_timing=True)
Collaborator:

What's the overhead of this? Should we make it opt-in?

Contributor Author:

The Event creation and recording itself is in the noise, but the synchronize further down could theoretically be costly for some combination of model architecture and request type.

Hence that part is protected by a flag. All the other ones (scheduling time, the model_execute call) are not protected, since they are low overhead.

@@ -1379,6 +1385,16 @@ def execute_model(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
Collaborator:

Does it do a CPU <> GPU sync?

Contributor Author:

It is a possibility, if I understand correctly. In practice, though, once the sampling has finished, there should be no work pending on the GPU.
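One way to sanity-check that "no pending GPU work" assumption, should it matter, is torch.cuda.Event.query(), which reports without blocking whether all work captured by the event has completed (illustrative sketch only, assuming a CUDA device):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    torch.zeros(1024, device="cuda").sum()  # stand-in for the forward pass
    end.record()

    torch.cuda.synchronize()  # stand-in for "sampling has finished"
    # query() is non-blocking; True means synchronize()/elapsed_time() on this
    # event will not have to wait for outstanding GPU work.
    assert end.query()
    print(f"forward: {start.elapsed_time(end):.3f} ms")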

# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = model_forward_time
Collaborator:

Does it work in the PP case?

Contributor Author:

Should I test the PP case? Can you give me a command line or a model that I can do this with? I have tested it with tensor parallelism on the Llama models. I don't yet fully understand how all the code is laid out to be confident about where the request processing joins in on the PP paths.

@zhisbug (Collaborator) commented on Aug 6, 2024:

@rkooo567 ci seems to look ok, let's merge?

@sfc-gh-mkeralapura (Contributor Author):

> @rkooo567 ci seems to look ok, let's merge?

I got a configuration with pipeline parallelism running, but I don't think the ObservabilityConfig is being correctly passed into the model runner. I will try to debug that tomorrow morning and respond.

Comment on lines 1395 to 1396
if (self.observability_config.collect_model_forward_time
and output is not None):
Contributor:

Suggested change
if (self.observability_config.collect_model_forward_time
and output is not None):
if (self.observability_config
and self.observability_config.collect_model_forward_time
and output is not None):

To address the CI error:

vllm/worker/model_runner.py:1395: error: Item "None" of "Optional[Any]" has no attribute "collect_model_forward_time"  [union-attr]

please add a check to ensure self.observability_config is not None. Additionally, run ./format.sh to fix any other linting issues.

Contributor Author:

Done.

@rkooo567 (Collaborator) left a comment:

a couple last comments!

default=EngineArgs.collect_model_forward_time,
help="If set to True and --otlp-traces-endpoint is set, "
"collects model forward time in traces. This involves "
"use of a blocking operation and hence might have a "
Collaborator:

Can you actually test this? This is kind of interesting, because tracing, afaik, is a production feature, but this will block it from being used in production...

Contributor Author:

Yes, we have an internal setup at Snowflake (where we collect these traces), and I do see the traces working, including the model forward time.

@@ -92,6 +92,9 @@ class SpanAttributes(BaseSpanAttributes):
LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
LLM_LATENCY_E2E = "gen_ai.latency.e2e"
LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
Collaborator:

Can you update the comments for these traces in this file? Basically the same comments as:

    # Total time spent in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Total time spent in the model execute function. This will include model
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

Contributor Author:

Done.
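For reference, the documented constants might end up looking roughly like this (comment wording adapted from the model_forward_time / model_execute_time comments quoted above; not necessarily the final text, and the real class derives from BaseSpanAttributes):

    class SpanAttributes:
        # Time the request spent waiting in the scheduler.
        LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
        # Total time spent in the forward pass for this request across all workers.
        LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
        # Total time spent in the model execute function. This includes the model
        # forward pass, block/sync across workers, CPU-GPU sync time and sampling.
        LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"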

@@ -1370,6 +1375,7 @@ def execute_model(
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
model_forward_end.record()
Collaborator:

Can you also skip the record() calls when the feature is not enabled?

Collaborator:

I measured the time, and it seems to take 100-200 us to record (since we want to keep the e2e control-plane overhead very small, 100-200 us seems pretty big for unnecessary work).

Contributor Author:

Done.

# If set, collects the model forward time for the request. This introduces a
# possibly blocking operation to accurately collect the GPU time. It can
# have a performance impact on the request latency.
collect_model_forward_time: bool = False
Collaborator:

Can we raise an exception if it is true and tracing is not enabled?

Contributor Author:

Done.
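A hypothetical sketch of that validation (the class and field names mirror the snippets quoted in this thread but are illustrative, not the exact vLLM config):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ObservabilityConfig:
        otlp_traces_endpoint: Optional[str] = None
        collect_model_forward_time: bool = False

        def __post_init__(self):
            if self.collect_model_forward_time and self.otlp_traces_endpoint is None:
                raise ValueError(
                    "collect_model_forward_time requires --otlp-traces-endpoint "
                    "to be set, since the timing is only reported via traces.")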

@comaniac (Collaborator) commented on Aug 8, 2024:

Thanks for this useful feature! We should later add a profiling tutorial to the vLLM docs that covers this.

@sfc-gh-mkeralapura (Contributor Author):

> a couple last comments!

Thanks for the comments in both rounds.

The PP code path was not working. I have now tested it on our internal benchmark framework: I accumulate model_execute_time through the IntermediateTensors, and I think it should now be working. For model forward time, I have not yet figured out how to do it.

I have put all the code behind config.collect_* vars, and I have gone with the arg change you originally suggested, using a comma-separated string.

On model_forward_time: the PP path is not collecting the right metric. There are three options for how to proceed: (1) figure out how to do it correctly - I will spend today trying that; (2) leave it as is, since it works without PP and is useful to us, but add a comment calling it out and fix it in a follow-up; (3) drop the model_forward change from this PR and follow up on it, say, next week. I will probably do (1) followed by (3).

@rkooo567 (Collaborator) commented on Aug 8, 2024:

> On model_forward_time: the PP path is not collecting the right metric. There are three options for how to proceed: (1) figure out how to do it correctly - I will spend today trying that; (2) leave it as is, since it works without PP and is useful to us, but add a comment calling it out and fix it in a follow-up; (3) drop the model_forward change from this PR and follow up on it, say, next week. I will probably do (1) followed by (3).

Can you just raise an exception if PP > 1 and this feature is enabled in this PR? The direction sounds great!

@rkooo567 (Collaborator) left a comment:

LGTM. Let's just raise an exception if PP is used with this feature for now! (Or print a warning and disable it automatically.)

@sfc-gh-mkeralapura (Contributor Author):

> On model_forward_time: the PP path is not collecting the right metric. There are three options for how to proceed: (1) figure out how to do it correctly - I will spend today trying that; (2) leave it as is, since it works without PP and is useful to us, but add a comment calling it out and fix it in a follow-up; (3) drop the model_forward change from this PR and follow up on it, say, next week. I will probably do (1) followed by (3).

> Can you just raise an exception if PP > 1 and this feature is enabled in this PR? The direction sounds great!

Done. Doing one more round of testing, and then I will try to merge.
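A hypothetical sketch of such a guard (names are illustrative, not the exact vLLM code):

    def check_detailed_trace_support(collect_model_forward_time: bool,
                                     pipeline_parallel_size: int) -> None:
        # Reject the unsupported combination early, as discussed above.
        if collect_model_forward_time and pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Collecting model forward time is not supported with "
                "pipeline parallelism (PP > 1) yet.")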

@sfc-gh-mkeralapura (Contributor Author):

/ready

The github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 8, 2024.
@sfc-gh-mkeralapura (Contributor Author):

Ah, looks like I can't merge. Can one of you with write access merge this in?

Thanks a ton for the review & suggestions!

@zhisbug merged commit 933790c into vllm-project:main on Aug 9, 2024
45 of 50 checks passed
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
5 participants