
Add unit test for Mixtral MoE layer #2677

Merged (32 commits) on Jan 31, 2024

Conversation

pcmoritz (Collaborator)

This PR adds a unit test for the Mixtral MoE layer to vLLM.

It is based on @casper-hansen's test in https://github.com/casper-hansen/AutoAWQ/blob/mixtral_fused/tests/test_fused_moe.py

@@ -141,8 +142,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         selected_experts,
                                         inplace=True)

-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
pcmoritz (Collaborator, Author) commented on Jan 31, 2024

I'm curious what the thoughts on this are. On the one hand, there is a ton of value in being able to run all the layers in a single process without having to stand up the distributed environment. On the other hand, this is redundant with the check in tensor_model_parallel_all_reduce, which tests for get_tensor_model_parallel_world_size() == 1 (but we can't rely on that here, since get_tensor_model_parallel_world_size itself already needs the distributed environment). I considered monkey patching that function, but that doesn't seem great either.
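
For context, here is a minimal sketch of the two places the world-size check can live. The function bodies below are illustrative stand-ins, not vLLM's actual implementations:

import torch
import torch.distributed as dist

def tensor_model_parallel_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in: the helper already short-circuits when the world
    # size is 1, but just asking for the world size requires the distributed
    # environment to be initialized.
    if dist.get_world_size() == 1:
        return x
    dist.all_reduce(x)
    return x

def finish_moe_forward(final_hidden_states: torch.Tensor,
                       tp_size: int) -> torch.Tensor:
    # Guarding in the layer duplicates that check, but it lets a single-process
    # unit test run the layer without ever touching torch.distributed.
    if tp_size > 1:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)
    return final_hidden_states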

simon-mo (Collaborator)

https://buildkite.com/vllm/ci/builds/744#018d5d59-27fa-4a6c-bc07-f15bd8d544fc/51-61
Looks like the tests failed. The Docker image is public; you can reproduce them by running on L4 instances in GCP.

pcmoritz (Collaborator, Author)

I understand the problem; it comes from using PyTorch with CUDA together with pytest-forked. I think the best solution (short of switching to pytest-xdist instead of pytest-forked, which is a big change) is to isolate the layer-wise tests from the model tests and just run pytest without --forked for the layer-wise tests (since these should be fast).

simon-mo (Collaborator)

simon-mo commented Jan 31, 2024 via email

pcmoritz (Collaborator, Author)

Ok, sounds good, let's do that then. If it gets too slow going forward, we can switch to pytest-xdist :)

pcmoritz changed the title from "Add unit test for Mixtral MoE layer" to "Add end-to-end test for Mixtral and unit test for Mixtral MoE layer" on Jan 31, 2024
pcmoritz (Collaborator, Author) commented on Jan 31, 2024

I also added the end-to-end test for Mixtral.

EDIT: It doesn't fit into the GPU memory; I could have seen that coming.

pcmoritz changed the title from "Add end-to-end test for Mixtral and unit test for Mixtral MoE layer" back to "Add unit test for Mixtral MoE layer" on Jan 31, 2024
pcmoritz (Collaborator, Author) commented on Jan 31, 2024

I had to move it to the kernels tests, which seems like a fine place for it. The model tests don't like it if somebody else is using their precious GPU memory :D

casper-hansen (Contributor)

I’m just curious, is there any way this implementation can run FP16 and match the original implementation without a large difference in logits?

If not, do we understand why running both in FP16 gives such a large difference?

pcmoritz (Collaborator, Author) commented on Jan 31, 2024

Even in bfloat16 the difference is not that large. The best precision comes from float32 without tensor cores (see https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere), and that's the best setting to check the correctness of the algorithm. But if we want the highest performance, we have to sacrifice some accuracy. In the future we will probably offer the option to do this arithmetic in fp8, which will be even less accurate, and then we will probably have to do some scaling to get good results.
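
For reference, this is roughly how tensor cores can be turned off for the float32 correctness check. These are the standard PyTorch flags from the linked docs; exactly where to set them in the test is left open here:

import torch

# Disable TF32 so float32 matmuls run at full precision on Ampere+ GPUs;
# this is the most accurate setting for checking the algorithm itself.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False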

I don't think this algorithm is fundamentally less numerically stable than what we had before; it is all just a big blocked matrix multiplication. The only difference is that this one is more cache-efficient (and does less work, because we only do the work for the experts that are actually used).
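
To make the "it is all just a big blocked matrix multiplication" point concrete, here is a minimal dense reference for a Mixtral-style top-2 MoE forward pass. This is plain PyTorch, not the fused kernel, and the weight layout and gate/up/down naming are assumptions for illustration:

import torch
import torch.nn.functional as F

def reference_moe(hidden_states, gate_w, w1, w2, w3, top_k=2):
    """Naive top-k MoE: route each token, then run only the selected experts.

    hidden_states: [num_tokens, hidden]
    gate_w:        [num_experts, hidden]                 router weights
    w1, w3:        [num_experts, intermediate, hidden]   gate/up projections
    w2:            [num_experts, hidden, intermediate]   down projection
    """
    # Softmax over experts, keep the top-k, renormalize their weights.
    routing = F.softmax(hidden_states @ gate_w.t(), dim=-1)
    topk_weights, topk_ids = torch.topk(routing, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden_states)
    for expert in range(gate_w.shape[0]):
        token_idx, slot_idx = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # no token routed to this expert, skip its matmuls
        x = hidden_states[token_idx]
        # SwiGLU MLP: silu(x @ w1^T) * (x @ w3^T), projected back down by w2.
        h = F.silu(x @ w1[expert].t()) * (x @ w3[expert].t())
        y = h @ w2[expert].t()
        out.index_add_(0, token_idx,
                       y * topk_weights[token_idx, slot_idx, None])
    return out

The fused kernel computes the same thing in a single pass over the routed tokens, which is where the cache efficiency comes from.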

pcmoritz (Collaborator, Author) commented on Jan 31, 2024

Here are the numerical differences for different dtypes (the maximum absolute difference of the hidden states after the MoE layer):

tensor(0.0005, device='cuda:0', dtype=torch.float16)
tensor(0.0029, device='cuda:0', dtype=torch.bfloat16)
tensor(0.0008, device='cuda:0') <- dtype = float32

(Also note that these numbers vary a little from run to run.) And with float32 without tensor cores it is:

tensor(0.0002, device='cuda:0')

pcmoritz (Collaborator, Author) commented on Jan 31, 2024

By the way, one more point of comparison: even if you take just the HuggingFace implementation, look at the difference between evaluating it in float16 vs. float32:

import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to("cuda")

inputs = torch.randn((1, 16, config.hidden_size)).to("cuda")

# Run the same MoE block once in float32 and once in float16.
hf_states1, _ = hf_moe.forward(inputs)
hf_states2, _ = hf_moe.to(torch.float16).forward(inputs.to(torch.float16))

# Maximum absolute difference between the two outputs.
print("diff", torch.max(abs(hf_states1 - hf_states2.to(torch.float32))))

diff tensor(0.0003, device='cuda:0')

And this is almost the best possible case: if you compare float16 weights on float16 inputs with bfloat16 weights on bfloat16 inputs, the error is much worse (0.0032). So we can actually be pretty happy about the accuracy we are getting. There is a decent amount of inherent inaccuracy in the problem itself :)

pcmoritz (Collaborator, Author)

@simon-mo This is now ready! I also fixed two failures in the kernels CI test along the way: one here (libcuda.so not found), because it is closely related to this PR, and the other in #2684, which is unrelated to this PR.

pcmoritz (Collaborator, Author)

@casper-hansen I have also added tests for the other dtypes now :)
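
As a sketch of what the dtype coverage can look like, here is a parametrized variant of the snippet earlier in the thread. It compares the HuggingFace block against its own float32 output rather than against the vLLM layer, so it is not the actual test, and the tolerances are ballpark figures taken from the numbers above:

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Rough per-dtype tolerances, in line with the max-abs differences above.
TOLERANCES = {torch.float16: 1e-3, torch.bfloat16: 5e-3}

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_moe_dtype_drift(dtype):
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to("cuda")
    inputs = torch.randn((1, 16, config.hidden_size)).to("cuda")

    ref_states, _ = hf_moe.forward(inputs)                       # float32 reference
    test_states, _ = hf_moe.to(dtype).forward(inputs.to(dtype))  # lower precision

    assert torch.allclose(ref_states, test_states.to(torch.float32),
                          atol=TOLERANCES[dtype])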

simon-mo merged commit d0d93b9 into vllm-project:main on Jan 31, 2024 (17 checks passed)