support fused weights for export_model #8554

Merged (1 commit) on Jun 6, 2024
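The diff below teaches set_state_dict to accept checkpoints whose attention and MLP weights were exported in fused form (self_attn.qkv_proj.weight and mlp.gate_up_fused_proj.weight) as well as the unfused q/k/v and gate/up projections. The following standalone numpy sketch illustrates the branching idea only; split_qkv is a hypothetical stand-in for paddlenlp.transformers.conversion_utils.split_param_func, under the simplifying assumptions of a plain Q/K/V concatenation along the output axis, equal Q and KV head counts, and no tensor parallelism.

import numpy as np

# Hypothetical stand-in for split_param_func(..., is_qkv=True, ...), assuming the
# fused weight is a plain concatenation of Q, K and V along the last axis.
def split_qkv(fused):
    return np.split(fused, 3, axis=-1)

def qkv_for_inference(state_dict, prefix="llama.layers.0.self_attn"):
    # Fused checkpoint: split qkv_proj back into Q, K and V first.
    if f"{prefix}.qkv_proj.weight" in state_dict:
        q, k, v = split_qkv(state_dict[f"{prefix}.qkv_proj.weight"])
    # Unfused checkpoint: take the three projections directly.
    else:
        q = state_dict[f"{prefix}.q_proj.weight"]
        k = state_dict[f"{prefix}.k_proj.weight"]
        v = state_dict[f"{prefix}.v_proj.weight"]
    # Both paths end in the [3 * num_heads * head_size, hidden_size] layout
    # consumed by the fused inference transformer block.
    return np.concatenate([q, k, v], axis=-1).transpose(1, 0)

hidden_size, num_heads, head_size = 64, 4, 16
fused = np.random.rand(hidden_size, 3 * num_heads * head_size).astype("float32")
out = qkv_for_inference({"llama.layers.0.self_attn.qkv_proj.weight": fused})
assert out.shape == (3 * num_heads * head_size, hidden_size)

Either way, the rest of set_state_dict works on a single per-layer QKV tensor, so the weight-only and a8w8 quantization paths further down are unchanged.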
113 changes: 70 additions & 43 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -47,6 +47,7 @@
GenerationInferenceModel,
)
from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
from paddlenlp.transformers.conversion_utils import split_param_func

from paddlenlp.transformers.llama.modeling import LlamaLMHead
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -473,48 +474,66 @@
def set_state_dict(self, state_dict):
unfused_state_dict = {}
head_size = self.hidden_size // self.num_attention_heads
split_fn = split_param_func()

self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
self.embed_tokens.weight.set_value(
paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype)
)
Review comment (Contributor Author): .cast supports the case where the original weights are stored in bfloat16.
self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"]).cast(self.norm.weight.dtype))

for idx in range(self.config.num_hidden_layers):
logger.info(f"set state for layer {idx}")

if self.use_weight_only:
logger.info("weight only is enabled")
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
concated_qkv_weight = np.concatenate(
split_fn(
state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
is_qkv=True,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
),
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
else:
unfused_state_dict = {}
unfused_state_dict["self_attn.q_proj.weight"] = state_dict[

Check warning on line 501 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L500-L501

Added lines #L500 - L501 were not covered by tests
"llama.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["self_attn.k_proj.weight"] = state_dict[

Check warning on line 504 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L504

Added line #L504 was not covered by tests
"llama.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[

Check warning on line 507 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L507

Added line #L507 was not covered by tests
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]
concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
self.hidden_size,
)
)  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
ffn1_weight_tensor = np.concatenate(
split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
)
else:
unfused_state_dict["mlp.gate_proj.weight"] = state_dict[

Check warning on line 530 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L530

Added line #L530 was not covered by tests
"llama.layers.{}.mlp.gate_proj.weight".format(idx)
]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
)  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)

unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]

concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
@@ -534,7 +553,9 @@
paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
)
else:
self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
self.transformer_block.qkv_weights[idx].set_value(
qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
)

linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
if self.use_weight_only:
@@ -556,7 +577,9 @@
)
)
else:
self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
self.transformer_block.linear_weights[idx].set_value(
linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
)

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +595,9 @@
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
self.transformer_block.ffn1_weights[idx].set_value(
ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
)

ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
if self.use_weight_only:
@@ -594,7 +619,9 @@
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
self.transformer_block.ffn2_weights[idx].set_value(
ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
)

if self.quant_type == "a8w8":
if self.shift_smooth_all_linears:
@@ -660,16 +687,14 @@
)

self.transformer_block.ln_scales[idx].set_value(
paddle.to_tensor(
state_dict["llama.layers.{}.input_layernorm.weight".format(idx)],
dtype=self.transformer_block.ln_scales[idx].dtype,
paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast(
self.transformer_block.ln_scales[idx].dtype
)
)

self.transformer_block.ffn_ln_scales[idx].set_value(
paddle.to_tensor(
state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)],
dtype=self.transformer_block.ffn_ln_scales[idx].dtype,
paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast(
self.transformer_block.ffn_ln_scales[idx].dtype
)
)

@@ -1264,7 +1289,9 @@
@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.lm_head.weight.set_value(
paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
)
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


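As the inline review comment notes, loading now goes through paddle.to_tensor(...).cast(param.dtype) so that checkpoints stored in a different dtype, bfloat16 in particular, can be assigned to the model's parameters. A minimal self-contained sketch of that pattern follows; the shapes and dtypes are illustrative only and this is not code from the PR, assuming a Paddle build where float16 tensors are available on the current device.

import paddle

# Checkpoint value in one dtype, destination parameter in another; casting to the
# parameter's dtype before set_value avoids a dtype mismatch. A bfloat16
# checkpoint follows the same pattern.
ckpt_weight = paddle.randn([8, 8], dtype="float32")             # stands in for state_dict["..."]
param = paddle.create_parameter(shape=[8, 8], dtype="float16")  # stands in for e.g. self.norm.weight

param.set_value(paddle.to_tensor(ckpt_weight).cast(param.dtype))
print(param.dtype)  # paddle.float16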