Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Fix tracking of model forward time to the span traces in case of PP>1 #7440

Merged
merged 4 commits into from
Aug 16, 2024

Conversation

sfc-gh-mkeralapura
Copy link
Contributor

This is a quick follow up to #7089. In that PR we left the PP>1 unsupported for the model forward time. Fixing that here.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@comaniac
Copy link
Collaborator

cc @rkooo567

vllm/worker/model_runner.py Show resolved Hide resolved
vllm/worker/model_runner.py Show resolved Hide resolved
Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Leave to @rkooo567

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: is it possible to add unit tests for this?

orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we store this to tensor? any way to just use cpu data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as far as I can tell, the only thing passed from the pipeline workers is the IntermediateTensors in serialized form. Hence added it to that. Is there a wrapper object of some form that holds these ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you try a regular python object here to see if it works?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. It looks like the worker serializes a Dict[Str, Any], so it can serialize floats too.

@sfc-gh-mkeralapura
Copy link
Contributor Author

QQ: is it possible to add unit tests for this?

Let me look into that. There is one just the reporting of these metrics and two the PP>1 case. Let me look into see how doable these are. I will circle back later in the day.

@sfc-gh-mkeralapura
Copy link
Contributor Author

QQ: is it possible to add unit tests for this?

I could not figure out how to get a unittest for this part of the worker. I instead added a test in the overall tracing test to test for these detailed trace data. It does not test for the pp>1 case though.

please take a look.

orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you try a regular python object here to see if it works?

assert metrics.model_execute_time is None


def test_traces_with_detailed_steps(trace_service):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the same test with pp= 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a bit more involved, but done.

@sfc-gh-mkeralapura
Copy link
Contributor Author

@rkooo567 Are you comfortable with this PR on the whole ?

@rkooo567 rkooo567 enabled auto-merge (squash) August 15, 2024 20:43
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 15, 2024
@sfc-gh-mkeralapura
Copy link
Contributor Author

/ready

auto-merge was automatically disabled August 15, 2024 21:58

Head branch was pushed to by a user without write access

@zhisbug zhisbug enabled auto-merge (squash) August 16, 2024 18:06
@zhisbug zhisbug disabled auto-merge August 16, 2024 20:35
@youkaichao youkaichao merged commit 93478b6 into vllm-project:main Aug 16, 2024
67 of 70 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
…ct#7440)

[Core] Fix tracking of model forward time to the span traces in case of PP>1 (vllm-project#7440)
zifeitong pushed a commit to zifeitong/vllm that referenced this pull request Aug 20, 2024
…ct#7440)

[Core] Fix tracking of model forward time to the span traces in case of PP>1 (vllm-project#7440)
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
…ct#7440)

[Core] Fix tracking of model forward time to the span traces in case of PP>1 (vllm-project#7440)
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
…ct#7440)

[Core] Fix tracking of model forward time to the span traces in case of PP>1 (vllm-project#7440)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants