Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Use torch.compile for GemmaRMSNorm #7642

Merged
merged 3 commits into from
Aug 22, 2024
Merged

[Misc] Use torch.compile for GemmaRMSNorm #7642

merged 3 commits into from
Aug 22, 2024

Conversation

WoosukKwon
Copy link
Collaborator

This PR is a temporary solution to accelerate Gemma models. The PR can be reverted once #7110 is merged.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 19, 2024
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@WoosukKwon WoosukKwon mentioned this pull request Aug 19, 2024
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even if this is a static function, I'm not sure this self would cause problem here.

if you want to be safe, I think you can move this function outside of the class definition.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be ok since it does not touch the states under self. I also checked that re-compilation does not happen after graph capturing, by monitoring the logs with TORCH_LOGS=guards. Also, the ShareGPT throughput benchmark shows 10~15% improvements.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as a temporary solution.

@WoosukKwon WoosukKwon merged commit b3856be into main Aug 22, 2024
39 of 41 checks passed
@WoosukKwon WoosukKwon deleted the gemma-rms branch August 22, 2024 08:14
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants