
[Bugfix] Fix weight loading for Chameleon when TP>1 #7410

Merged: 9 commits into vllm-project:main on Aug 13, 2024

Conversation

@DarkLight1337 (Member) commented on Aug 12, 2024

This PR fixes the inability to run Chameleon (7B and 30B) with tensor parallelism.

  • Fixed TP weight loading for layer normalization in the Chameleon model.
    • Added row_parallel_weight_loader (extracted from our code for the Command-R model) to the weight utils; a sketch of the loader follows this list.
  • Handled the case where the output logits can be None when applying the TP logits processor in the Chameleon model.
    • Fixed incorrect type annotations for the TP logits processor (it should return Optional[torch.Tensor] instead of torch.Tensor).
    • Handled _logits=None in compute_logits for the Medusa model.
  • Added missing vLLM vs. HF consistency tests for the Chameleon model.
    • Added postprocess_inputs to HfRunner to convert input dtypes for Chameleon, since its HF processor fails to do so; a hypothetical dtype-conversion hook is sketched below.
      • Updated the Mantis and MiniCPM-V tests to use postprocess_inputs instead of patching the processor directly.
    • Fixed the incorrect type of output_ids from VllmRunner, which caused failures when using check_outputs_equal.
      • Fixed incorrect type annotations in CompletionOutput (token_ids is now an array instead of a tuple).
  • Added a distributed test for the Chameleon model to ensure that it can be run with tensor parallelism.
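
For reference, here is a minimal sketch of the row-parallel weight loading pattern described above: each tensor-parallel rank narrows the checkpoint weight to its own shard along dim 0, while 1-D parameters are copied whole. The helper names (get_tensor_model_parallel_rank, default_weight_loader) follow vLLM's existing utilities, but the exact body of the merged loader may differ from this sketch.

```python
import torch

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def row_parallel_weight_loader(param: torch.Tensor,
                               loaded_weight: torch.Tensor) -> None:
    """Sketch: keep only this TP rank's shard of a row-parallel weight."""
    tp_rank = get_tensor_model_parallel_rank()
    # Parameters with more than one dimension are sharded along dim 0;
    # 1-D parameters are copied whole.
    shard_dim = 0 if param.dim() != 1 else None

    if shard_dim is not None:
        shard_size = param.data.shape[shard_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)

    # Fall back to the plain copy once the shard has been selected.
    default_weight_loader(param, loaded_weight)
```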

FIX #7388
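
As a purely illustrative example of the postprocess_inputs hook mentioned in the list above (the actual HfRunner wiring is not shown on this page, so the function name and call site here are assumptions), the hook simply casts floating-point tensors returned by the HF processor to the model dtype:

```python
import torch


def cast_chameleon_inputs(inputs: dict,
                          dtype: torch.dtype = torch.bfloat16) -> dict:
    """Hypothetical postprocess_inputs hook: cast floating-point tensors
    (e.g. pixel_values) to the model dtype, since the Chameleon HF
    processor leaves them in float32."""
    return {
        k: v.to(dtype)
        if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }
```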


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of the default ones by unblocking the steps in your fastcheck build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run the full CI, as it is required for merging (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add the ready label to the PR
  • Enable auto-merge

🚀

@jaywonchung (Contributor)

I backported the changes in this PR to my fork based on v0.5.4, and it resolves #7388 for me. My fork does not deviate significantly from v0.5.4, so I think this fixes it. Thanks a lot!

@ywang96 (Member) left a comment

@DarkLight1337 Is this something new, that _get_logits can return a None value? Why wasn't this caught previously?

@DarkLight1337 (Member, Author) commented Aug 13, 2024

It appears that Chameleon is the first model that actually uses the result of calling LogitsProcessor before returning it, so other models have not triggered this error.

Edit: Actually, Phi also uses it, but it already has a None check. On the other hand, Medusa fails to check for None - going to fix it.
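
For context, the failure mode looks roughly like the sketch below: under tensor parallelism the logits processor gathers logits onto the driver rank, so non-driver ranks receive None, and any post-processing (such as Chameleon masking out image tokens) has to guard against it. The class and attribute names here are illustrative, not a verbatim copy of the merged code.

```python
from typing import Optional

import torch
from torch import nn


class ChameleonLikeModel(nn.Module):
    # lm_head, logits_processor, model, etc. are assumed to be set in __init__.

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        # Under TP, the gather returns logits only on the driver rank;
        # every other rank receives None, so the masking must be guarded.
        if logits is not None:
            # Chameleon disallows image tokens during text-only generation.
            image_tokens = self.model.vocabulary_mapping.image_token_ids
            logits[:, image_tokens] = torch.finfo(logits.dtype).min
        return logits
```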

@ywang96 (Member) left a comment

Overall LGTM! Thank you for the fix!

(Resolved review thread on vllm/model_executor/model_loader/weight_utils.py; marked outdated)
@ywang96 (Member) commented Aug 13, 2024

> It appears that Chameleon is the first model that actually uses the result of calling LogitsProcessor before returning it, so other models have not triggered this error.
>
> Edit: Actually, Phi also uses it, but it has a None check.

Yeah, I checked again, and for most models we return the logits directly and let the downstream function handle the gather. Good find!

@DarkLight1337 DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 13, 2024
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 13, 2024 03:38
@DarkLight1337 DarkLight1337 merged commit 7025b11 into vllm-project:main Aug 13, 2024
54 checks passed
@DarkLight1337 DarkLight1337 deleted the fix-chameleon-tp branch August 13, 2024 05:35
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: none yet
Successfully merging this pull request may close this issue:
[Bug]: facebook/chameleon-30b triggers assertion error while loading weights
3 participants