[Kernel] Update fused_moe tuning script for FP8 (vllm-project#4457)
This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds tuned configurations for TP4. Note that in the generated configuration I removed num_warps and num_stages for small batch sizes, since doing so improved performance and brought the benchmarks in that regime back on par with the previous numbers, making sure this is a strict improvement over the status quo.

All the numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1 with 1000 input and 50 output tokens.

Before this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency 
qps = 4: 10.1 ms ITL, 0.52s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency
qps = 4: 10.2 ms ITL, 0.53s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
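
For reference, the updated script is now driven by a --dtype flag (see the argparse block at the end of the diff), e.g. `python benchmarks/kernels/benchmark_mixtral_moe.py --dtype float8`. The sketch below summarizes the FP8 call the script times: the argument names follow run_timing in the diff, while the shapes are illustrative Mixtral-8x7B/TP4 values and an FP8-capable GPU (e.g. H100) is assumed.

```python
import torch

from vllm.model_executor.layers.fused_moe import fused_moe

# Illustrative Mixtral-8x7B shapes with TP4: 8 experts, d_model = 4096,
# intermediate size 14336 sharded to 14336 // 4 = 3584, batch size 8.
bs, d_model, num_experts, shard_intermediate_size = 8, 4096, 8, 3584

hidden_states = torch.rand(bs, d_model, device="cuda", dtype=torch.float16)

# FP8 expert weights with per-expert weight scales and per-tensor (static)
# activation scales, matching the tensors run_timing constructs below.
w1 = torch.rand(num_experts, 2 * shard_intermediate_size, d_model,
                device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w2 = torch.rand(num_experts, d_model, shard_intermediate_size,
                device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_experts, device="cuda", dtype=torch.float32)
w2_scale = torch.ones(num_experts, device="cuda", dtype=torch.float32)
a1_scale = torch.ones(1, device="cuda", dtype=torch.float32)
a2_scale = torch.ones(1, device="cuda", dtype=torch.float32)

gating_output = torch.rand(bs, num_experts, device="cuda", dtype=torch.float32)

out = fused_moe(hidden_states=hidden_states, w1=w1, w2=w2,
                w1_scale=w1_scale, w2_scale=w2_scale,
                a1_scale=a1_scale, a2_scale=a2_scale,
                gating_output=gating_output, topk=2, renormalize=True,
                inplace=True, use_fp8=True)
```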
pcmoritz authored and dtrifiro committed May 7, 2024
1 parent 12d6ba8 commit bd6c7cc
Showing 2 changed files with 211 additions and 38 deletions.
109 changes: 71 additions & 38 deletions benchmarks/kernels/benchmark_mixtral_moe.py
@@ -1,27 +1,29 @@
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def main():
def main(dtype: str):
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)
run_grid(bs, method=method, dtype=dtype)


def run_grid(bs, method):
def run_grid(bs, method, dtype: str):
d_model = 4096
num_total_experts = 8
top_k = 2
@@ -34,39 +36,29 @@ def run_grid(bs, method):
num_trials = 1

configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]

for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_m in [16, 32, 64, 128, 256]:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})
for num_stages in [2, 3, 4, 5]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
})

best_config = None
best_time_us = 1e20

for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
print(f'{tp_size=} {bs=}')

for config in tqdm(configs):
# warmup
print('warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
@@ -79,12 +71,12 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
)
except triton.runtime.autotuner.OutOfResources:
continue

# trial
print('benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
@@ -96,6 +88,7 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
)

kernel_dur_us = 1000 * kernel_dur_ms
@@ -105,16 +98,18 @@ def run_grid(bs, method):
best_config = config
best_time_us = kernel_dur_us

print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')
tqdm.write(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')

print("best_time_us", best_time_us)
print("best_config", best_config)

# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size)
model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
@@ -128,27 +123,48 @@ def run_grid(bs, method):

def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float:
config, dtype: str) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.bfloat16,
dtype=torch.float16,
)

ws = torch.rand(
w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

w2s = torch.rand(
w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None

if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
w2_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
a1_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
a2_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)

gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=ws,
w2=w2s,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=dtype == "float8",
)
end_event.record()
end_event.synchronize()
@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,


if __name__ == "__main__":
sys.exit(main())
parser = argparse.ArgumentParser(
prog='benchmark_mixtral_moe',
description='Benchmark and tune the fused_moe kernel',
)
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations',
)
args = parser.parse_args()
sys.exit(main(args.dtype))
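
The second changed file (its path is not visible in this capture) is one of the new tuned FP8 configuration files for TP4. As a rough sketch of how such a file can be located and inspected after tuning, assuming a CUDA device is visible (the file name embeds the device name) and that the file sits where the script wrote it:

```python
import json

from vllm.model_executor.layers.fused_moe import get_config_file_name

# 8 experts; Mixtral-8x7B intermediate size 14336 sharded over TP4 -> 3584.
# Passing "float8" selects the FP8-specific file name, as in the diff above.
filename = get_config_file_name(8, 14336 // 4, "float8")
print(filename)  # device-specific, e.g. "E=8,N=3584,...,dtype=float8.json"

# The file holds Dict[str, Dict[str, int]]: batch size -> best Triton config.
with open(filename) as f:
    configs = json.load(f)
print(configs["8"])  # {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, ...}
```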
@@ -0,0 +1,140 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
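
At inference time an exact batch-size match is not required: around this version of vLLM, fused_moe falls back to the tuned entry whose batch-size key is closest to the current number of tokens. The snippet below is an illustrative reimplementation of that selection, not code copied from vLLM, and the file name is hypothetical:

```python
import json
from typing import Dict


def select_config(configs: Dict[str, Dict[str, int]], m: int) -> Dict[str, int]:
    """Pick the tuned Triton config whose batch-size key is closest to m."""
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]


# With the TP4 FP8 table above, a batch of 100 tokens would use the entry
# tuned for batch size 96 (the file name here is hypothetical).
with open("E=8,N=3584,dtype=float8.json") as f:
    tuned = json.load(f)
print(select_config(tuned, 100))
```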
