[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin #7701

Merged Sep 23, 2024 (27 commits)

Changes from 1 commit

Commits (27)
458d69e
squash-patch changes
LucasWilkinson Jul 31, 2024
1ee3608
remove gptq support
LucasWilkinson Aug 30, 2024
ab7507e
formatting + fixes
LucasWilkinson Aug 30, 2024
68ff26d
add gptq_marlin support back
LucasWilkinson Aug 31, 2024
7b9e8b2
remove extra prints
LucasWilkinson Aug 31, 2024
30f1056
add machete act ordering
LucasWilkinson Sep 6, 2024
3bbb902
update heuristic
LucasWilkinson Sep 6, 2024
196a9f2
add to tests
LucasWilkinson Sep 6, 2024
38f5b84
update benchmark
LucasWilkinson Sep 6, 2024
c59449b
tweak for llama 405b
LucasWilkinson Sep 6, 2024
3048911
env var for disabling kernels
LucasWilkinson Sep 10, 2024
df7c4c0
format + mypy
LucasWilkinson Sep 11, 2024
6f3f707
yapf format
LucasWilkinson Sep 11, 2024
90b8e03
refactor
LucasWilkinson Sep 11, 2024
c264c7a
add g_idx back
LucasWilkinson Sep 11, 2024
2d25a9a
clean-up
LucasWilkinson Sep 11, 2024
62508c5
review comments
LucasWilkinson Sep 12, 2024
84cfdb2
fix codespell
LucasWilkinson Sep 12, 2024
c452a86
TorchDynamo Compatibility
LucasWilkinson Sep 13, 2024
096dd4a
add permute cols opcheck
LucasWilkinson Sep 13, 2024
a98f691
fix correctness test
LucasWilkinson Sep 16, 2024
7c02bcf
bug in filtering kernels by compute capability
LucasWilkinson Sep 16, 2024
95a85c9
Merge remote-tracking branch 'origin/main' into lwilkinson/machete-en…
LucasWilkinson Sep 20, 2024
a019473
add requirements.txt
LucasWilkinson Sep 20, 2024
306b283
Merge branch 'main' into lwilkinson/machete-end2end
mgoin Sep 21, 2024
e32bfc5
[dbrx] refactor dbrx experts to extend FusedMoe class (#8518)
divakar-amd Sep 21, 2024
05752e9
[Kernel][Bugfix] Delete some more useless code in marlin_moe_ops.cu (…
tlrmchlsmth Sep 21, 2024
[dbrx] refactor dbrx experts to extend FusedMoe class (#8518)
divakar-amd authored and LucasWilkinson committed Sep 22, 2024
commit e32bfc5bb053ba2a57f2cb3ef61e3919f025071d
120 changes: 51 additions & 69 deletions vllm/model_executor/models/dbrx.py
@@ -7,9 +7,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@@ -22,7 +21,6 @@
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig

@@ -54,63 +52,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return router_logits


class DbrxExperts(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
class DbrxExperts(FusedMoE):

def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
super().__init__(
num_experts=config.ffn_config.moe_num_experts,
top_k=config.ffn_config.moe_top_k,
hidden_size=config.d_model,
intermediate_size=config.ffn_config.ffn_hidden_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model
self.intermediate_size = (config.ffn_config.ffn_hidden_size //
self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
self.tp_size)

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
))
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
))

set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)

# Define custom weight loader for dbrx model
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
tp_rank = get_tensor_model_parallel_rank()
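As context for the constructor change above: DbrxExperts now passes the MoE configuration straight to FusedMoE and keeps only the per-rank intermediate size, i.e. the full ffn_hidden_size split evenly across tensor-parallel ranks. A minimal sketch of that sizing arithmetic; the concrete numbers are hypothetical stand-ins, the real values come from DbrxConfig:

```python
# Sketch of the tensor-parallel sizing done in DbrxExperts.__init__ above.
# All concrete numbers here are hypothetical; real values come from DbrxConfig.
ffn_hidden_size = 10752   # full per-expert intermediate size (hypothetical)
d_model = 6144            # hidden size (hypothetical)
num_experts = 16          # moe_num_experts (hypothetical)
tp_size = 8               # tensor-parallel world size (hypothetical)

# Each rank keeps an equal slice of every expert's projections.
intermediate_size_per_rank = ffn_hidden_size // tp_size  # 1344

# Mirroring the removed ws/w2s shapes, the fused per-rank expert weights are:
w13_shape = (num_experts, 2 * intermediate_size_per_rank, d_model)  # gate + up proj
w2_shape = (num_experts, d_model, intermediate_size_per_rank)       # down proj
print(intermediate_size_per_rank, w13_shape, w2_shape)
```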
@@ -140,26 +107,40 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]


class DbrxMoE(nn.Module):
"""A tensor-parallel MoE implementation for DBRX.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""

def __init__(
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.d_model = config.d_model
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

self.router = DbrxRouter(config, self.params_dtype)

self.experts = DbrxExperts(config=config,
quant_config=quant_config,
params_dtype=self.params_dtype)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states.view(num_tokens, hidden_size)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)


class DbrxAttention(nn.Module):
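The new DbrxMoE class above keeps the old module's external interface (a single forward over hidden states) and hands the routing logits to the FusedMoE-based DbrxExperts, which now performs top-k dispatch and the tensor-parallel all-reduce internally (reduce_results=True). A CPU-only sketch of the flatten → route → experts → reshape flow, with a plain Linear standing in for the experts; names and sizes here are illustrative, not vLLM API:

```python
import torch
import torch.nn as nn

# Illustrative sizes only.
d_model, num_experts = 64, 4
router = nn.Linear(d_model, num_experts, bias=False)      # stand-in for DbrxRouter
experts_stub = nn.Linear(d_model, d_model, bias=False)    # stand-in for DbrxExperts

hidden_states = torch.randn(2, 8, d_model)                # (batch, seq, d_model)
orig_shape = hidden_states.shape
flat = hidden_states.view(-1, d_model)                    # (num_tokens, d_model)
router_logits = router(flat)                              # (num_tokens, num_experts)
out = experts_stub(flat)                                  # real code: self.experts(flat, router_logits)
out = out.view(orig_shape)                                # restore original shape
assert out.shape == hidden_states.shape
```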
@@ -288,7 +269,7 @@ def __init__(
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
quant_config)
self.ffn = DbrxExperts(config, quant_config)
self.ffn = DbrxMoE(config, quant_config)

def forward(
self,
@@ -409,9 +390,10 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
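For reference, the updated expert_params_mapping pairs each DBRX checkpoint expert tensor with its FusedMoE parameter: w1 and v1 both load into the fused w13_weight, while w2 loads into w2_weight. Expanding the comprehension from the hunk above gives:

```python
# Expansion of the expert_params_mapping comprehension shown in the diff above.
expert_params_mapping = [(
    "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
    f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]

assert expert_params_mapping == [
    ("w13_weight", "mlp.w1"),
    ("w13_weight", "mlp.v1"),
    ("w2_weight", "mlp.w2"),
]
```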