Skip to content

Commit

Permalink
support fused weights for export_model
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 committed Jun 5, 2024
1 parent f36ed75 commit 8dc2cf7
Showing 1 changed file with 41 additions and 38 deletions.
79 changes: 41 additions & 38 deletions paddlenlp/experimental/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,47 +474,50 @@ def set_state_dict(self, state_dict):
unfused_state_dict = {}
head_size = self.hidden_size // self.num_attention_heads

self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"], dtype=self.embed_tokens.weight.dtype))

Check warning on line 477 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L477

Added line #L477 was not covered by tests
self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))

for idx in range(self.config.num_hidden_layers):
logger.info(f"set state for layer {idx}")

if self.use_weight_only:
logger.info("weight only is enabled")
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
concated_qkv_weight = state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)].transpose([1, 0])

Check warning on line 486 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L485-L486

Added lines #L485 - L486 were not covered by tests
else:
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[

Check warning on line 489 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L488-L489

Added lines #L488 - L489 were not covered by tests
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[

Check warning on line 492 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L492

Added line #L492 was not covered by tests
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[

Check warning on line 495 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L495

Added line #L495 was not covered by tests
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]
concated_qkv_weight = (

Check warning on line 498 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L498

Added line #L498 was not covered by tests
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
)
                ) # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
concated_ffn1_weight = state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]

Check warning on line 514 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L513-L514

Added lines #L513 - L514 were not covered by tests
else:
unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
concated_ffn1_weight = np.concatenate(

Check warning on line 518 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L516-L518

Added lines #L516 - L518 were not covered by tests
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
                ) # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)

unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]

concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
Expand All @@ -534,7 +537,7 @@ def set_state_dict(self, state_dict):
paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
)
else:
self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype))

Check warning on line 540 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L540

Added line #L540 was not covered by tests

linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
if self.use_weight_only:
Expand All @@ -556,7 +559,7 @@ def set_state_dict(self, state_dict):
)
)
else:
self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype))

Check warning on line 562 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L562

Added line #L562 was not covered by tests

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
Expand All @@ -572,7 +575,7 @@ def set_state_dict(self, state_dict):
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype))

Check warning on line 578 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L578

Added line #L578 was not covered by tests

ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
if self.use_weight_only:
Expand All @@ -594,7 +597,7 @@ def set_state_dict(self, state_dict):
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype))

Check warning on line 600 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L600

Added line #L600 was not covered by tests

if self.quant_type == "a8w8":
if self.shift_smooth_all_linears:
Expand Down Expand Up @@ -1264,7 +1267,7 @@ def forward(
@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.lm_head.weight.set_value(paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype))

Check warning on line 1270 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L1270

Added line #L1270 was not covered by tests
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


Expand Down

0 comments on commit 8dc2cf7

Please sign in to comment.