Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit test for Mixtral MoE layer #2677

Merged
merged 32 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
unify tests
  • Loading branch information
pcmoritz committed Jan 31, 2024
commit 5fb277b7249a5501bc3e2dfe0e0253daa000ee19
51 changes: 0 additions & 51 deletions tests/kernels/test_mixtral_moe.py

This file was deleted.

50 changes: 50 additions & 0 deletions tests/kernels/test_fused_moe.py → tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
Expand Down Expand Up @@ -48,3 +57,44 @@ def test_fused_moe(
triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
"Make sure our Mixtral MoE implementation agrees with the one from huggingface."

# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
vllm_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
params_dtype=dtype,
tp_size=1,
)

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data

# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")

# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(inputs)
vllm_states = vllm_moe.forward(inputs)

tol = {
torch.float32: 1e-3,
torch.float16: 1e-3,
torch.bfloat16: 1e-2,
}

assert torch.allclose(hf_states, vllm_states, rtol=tol[dtype], atol=tol[dtype])
Loading