Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Router Z-loss #151

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Router zloss
  • Loading branch information
josejg committed Sep 8, 2024
commit abc0638ef05f3b35f74368c5919838a9237dc80b
4 changes: 4 additions & 0 deletions megablocks/layers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class Arguments:
int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used)

# Router Z-loss arguments
moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
moe_zloss_in_fp32 : bool = False

def __post_init__(self):
if self.__getattribute__('mlp_impl') == 'grouped':
grouped_gemm.assert_grouped_gemm_is_available()
Expand Down
36 changes: 35 additions & 1 deletion megablocks/layers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,38 @@
from megablocks.layers import common
from megablocks.layers.arguments import Arguments

_ROUTER_LOGITS = []

def _save_router_logits(logits: torch.Tensor, args: Arguments):
if args.moe_zloss_weight == 0:
return
global _ROUTER_LOGITS
_ROUTER_LOGITS.append(logits)

def clear_router_zloss():
global _ROUTER_LOGITS
_ROUTER_LOGITS.clear()

def batched_router_zloss(args : Arguments):
global _ROUTER_LOGITS

if args.moe_zloss_weight == 0:
import warnings
warnings.warn("Call to batched_router_zloss, but moe_zloss_weight=0")
return 0

logits_per_router = _ROUTER_LOGITS

if args.moe_zloss_in_fp32:
logits_per_router = [logits.float() for logits in logits_per_router]

unscaled_zloss_per_router = torch.stack([
torch.logsumexp(logits, dim=1).square().mean()
for logits in logits_per_router
])

return args.moe_zloss_weight * unscaled_zloss_per_router


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign tokens uniformly
Expand Down Expand Up @@ -60,7 +92,9 @@ def forward(self, x: torch.Tensor):
if self.training and self.args.moe_jitter_eps is not None:
x = x * self.jitter(x)

scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
logits = self.layer(x.view(-1, x.shape[-1]))
_save_router_logits(logits, self.args)
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = self._top_k(scores)
if self.args.moe_normalize_expert_weights:
expert_weights = expert_weights / torch.norm(
Expand Down