Fused MOE for Mixtral #2542

Merged
merged 23 commits on Jan 30, 2024
Changes from 1 commit
lint and cleanup
pcmoritz committed Jan 22, 2024
commit 2867f346274c1f850597a42f447b2ef7e4266592
vllm/model_executor/layers/moe.py (4 changes: 1 addition, 3 deletions)
Collaborator:

@pcmoritz Should we move the MoE class back to the Mixtral model file? It seems like this MoE layer is not shared between Mixtral and DeepSeek.

Collaborator (author, pcmoritz):

Sounds good to me! Feel free to make any edits to the PR you'd like, or let me know if I should make them :)

Collaborator:

I'd appreciate it if you could do it!
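
For readers following the refactor being discussed, a rough sketch of a Mixtral-local MoE module that wraps the fused kernel is shown below. It is illustrative only: the stacked parameter names (`ws`, `w2s`, `gate`) mirror the diff in this commit, but the class layout and the exact `fused_moe` call signature are assumptions, not code from this PR.

```python
# Illustrative sketch only, not the code merged in this PR. It assumes a
# fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize) entry
# point; the real signature in vllm/model_executor/layers/moe.py may differ.
import torch
from torch import nn

from vllm.model_executor.layers.moe import fused_moe


class MixtralMoE(nn.Module):
    """Top-k mixture-of-experts block with stacked expert weights."""

    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int):
        super().__init__()
        self.top_k = top_k
        # Router: one logit per expert for every token.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Experts stacked along dim 0 so a single fused kernel can index them.
        # ws packs w1 (gate proj) and w3 (up proj); w2s is the down proj.
        self.ws = nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size, hidden_size))
        self.w2s = nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_tokens, hidden_size)
        router_logits = self.gate(hidden_states)
        # The exact fused_moe signature is assumed here, not taken from the PR.
        return fused_moe(hidden_states, self.ws, self.w2s, router_logits,
                         self.top_k, renormalize=True)
```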

@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -429,4 +427,4 @@ def fused_moe(hidden_states: torch.Tensor,
                          dim=1,
                          out=hidden_states)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+                     dim=1)
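
The hunk above only touches the tail of `fused_moe`, where the per-expert outputs in `intermediate_cache3` are summed over the top-k dimension. For reference, the unfused computation that a fused MoE kernel is expected to reproduce looks roughly like the sketch below; this is not the Triton kernel from this PR, and the weight layout and the renormalization step are assumptions.

```python
import torch
import torch.nn.functional as F


def reference_moe(hidden_states: torch.Tensor,  # (num_tokens, hidden)
                  w1: torch.Tensor,             # (num_experts, 2 * inter, hidden)
                  w2: torch.Tensor,             # (num_experts, hidden, inter)
                  gating_output: torch.Tensor,  # (num_tokens, num_experts)
                  topk: int) -> torch.Tensor:
    """Unfused top-k MoE, written only to illustrate the math; a fused kernel
    computes the same result with expert-batched matmuls."""
    weights = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(weights, topk, dim=-1)
    # Renormalize so each token's selected weights sum to 1 (an assumption
    # here; whether to renormalize is a per-model choice).
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden_states)
    inter = w2.shape[-1]
    for expert_id in range(w1.shape[0]):
        mask = topk_ids == expert_id              # (num_tokens, topk)
        if not mask.any():
            continue
        token_ids, slot = mask.nonzero(as_tuple=True)
        x = hidden_states[token_ids]
        # w1 packs the gate and up projections; SiLU-gate as in Mixtral.
        gate_up = x @ w1[expert_id].t()           # (n, 2 * inter)
        act = F.silu(gate_up[:, :inter]) * gate_up[:, inter:]
        expert_out = act @ w2[expert_id].t()      # (n, hidden)
        scale = topk_weights[token_ids, slot].unsqueeze(-1).to(expert_out.dtype)
        out.index_add_(0, token_ids, expert_out * scale)
    return out
```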
vllm/model_executor/models/mixtral.py (3 changes: 1 addition, 2 deletions)
@@ -1,4 +1,3 @@
-
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
@@ -272,7 +271,7 @@ def load_weights(self,
 
         expert_params_mapping = [
             # (param_name, weight_name, expert_id)
-            (f"ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
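
The change above is just the lint fix (dropping an `f` prefix from a string with no placeholders), but the comprehension it touches is what routes every per-expert checkpoint tensor into one of the two stacked parameters, `ws` (fused w1/w3) and `w2s`. A simplified, hypothetical illustration of how such a mapping can be consumed is shown below; the helper function and the slot layout inside `ws` are assumptions, not vLLM's actual loader.

```python
# Hypothetical illustration of consuming expert_params_mapping; the real
# load_weights in vLLM uses its own weight_loader machinery.
import torch

num_local_experts = 8  # e.g. Mixtral 8x7B (normally self.config.num_local_experts)

expert_params_mapping = [
    # (param_name, weight_name, expert_id)
    ("ws" if weight_name in ["w1", "w3"] else "w2s",
     f"experts.{expert_id}.{weight_name}.weight", expert_id)
    for expert_id in range(num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]


def load_expert_weight(params: dict, ckpt_name: str,
                       loaded: torch.Tensor) -> None:
    """Route one per-expert checkpoint tensor into a stacked parameter."""
    for param_name, ckpt_suffix, expert_id in expert_params_mapping:
        if ckpt_suffix not in ckpt_name:
            continue
        stacked = params[param_name]     # ws or w2s, experts stacked on dim 0
        if param_name == "w2s":
            # Down projection: one (hidden, inter) slice per expert.
            stacked[expert_id].copy_(loaded)
        else:
            # w1 and w3 share the fused ws tensor; here w1 takes the first
            # half of the rows and w3 the second (layout assumed).
            inter = loaded.shape[0]
            offset = 0 if ".w1." in ckpt_suffix else inter
            stacked[expert_id, offset:offset + inter].copy_(loaded)
        return
```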