
Commit

Reorder allreduce and residual add to enable better fusion (mlc-ai#941)
jinhongyii authored Sep 26, 2023
1 parent 13263a6 commit 5902cc6
Showing 1 changed file: mlc_llm/relax_model/llama.py (9 additions, 6 deletions).
@@ -179,8 +179,6 @@ def forward(self, x):
         up_result = self.up_proj(x)

         result = self.down_proj(relax.op.nn.silu(gate_result) * up_result)
-        if self.num_shards > 1:
-            result = nn.emit(ccl.allreduce(result, "sum"))
         return result


@@ -425,8 +423,6 @@ def forward(
         )

         attn_output = self.o_proj(attn_output)
-        if self.num_shards > 1:
-            attn_output = nn.emit(ccl.allreduce(attn_output, "sum"))
         return attn_output, ((None, None) if past_key_value is None else past_key_value)


@@ -460,14 +456,21 @@ def forward(
             attention_mask=attention_mask,
             all_seq_len_shape=all_seq_len_shape,
         )
+        if self.self_attn.num_shards > 1:
+            residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype))
         hidden_states = nn.emit(residual + hidden_states)

+        if self.self_attn.num_shards > 1:
+            hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))

         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if self.mlp.num_shards > 1:
+            residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype))
         hidden_states = nn.emit(residual + hidden_states)

+        if self.mlp.num_shards > 1:
+            hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))
         return hidden_states, present_key_value
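
The reorder is numerically equivalent: dividing the replicated residual by num_shards on every shard and then taking a sum-allreduce reproduces a single residual add, since sum_i(h_i + r / n) = sum_i(h_i) + r. The sketch below checks that identity with NumPy standing in for the tensor-parallel workers; it is an illustration, not MLC code, and the shapes, seed, and num_shards value are made up, with a plain Python sum playing the role of ccl.allreduce(..., "sum").

    # Minimal numerical sketch (illustrative only) of the reorder in this commit.
    import numpy as np

    num_shards = 4
    rng = np.random.default_rng(0)

    # Per-shard partial outputs of a sharded layer (attention or MLP).
    partials = [rng.standard_normal((2, 8)).astype("float32") for _ in range(num_shards)]
    # The residual stream is replicated on every shard.
    residual = rng.standard_normal((2, 8)).astype("float32")

    # Old order: allreduce the layer output first, then add the residual.
    old = residual + sum(partials)

    # New order: fold residual / num_shards into each shard's output, then allreduce once.
    new = sum(partial + residual / num_shards for partial in partials)

    np.testing.assert_allclose(old, new, rtol=1e-5, atol=1e-5)

With the division folded in, the elementwise residual add now sits directly before the allreduce in the decoder layer, which is presumably the adjacency that lets the allreduce be fused with the surrounding elementwise work, as the commit title suggests.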


