[LLM] support Qwen2 #8338

Merged 51 commits on Jun 11, 2024.

Commits
36ab9a7  add Qwen2Moe (DrownFish19, Apr 16, 2024)
3913e11  update default config (DrownFish19, Apr 17, 2024)
0aa1aca  Merge remote-tracking branch 'paddlenlp/develop' into dev_add_qwen1.5… (DrownFish19, Apr 17, 2024)
a29e90d  update QWen2Moe modeling (DrownFish19, Apr 18, 2024)
d514dff  update modeling (DrownFish19, Apr 18, 2024)
1e98323  update ckpt name (DrownFish19, Apr 19, 2024)
f81bb43  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, Apr 22, 2024)
37dd2d5  support same prefix model name for auto modeling (DrownFish19, Apr 25, 2024)
d12938a  update qwen2moe testing (DrownFish19, Apr 25, 2024)
8cc49fc  update qwen2moe modeling and config (DrownFish19, Apr 25, 2024)
9c8222e  update qwen2moe import (DrownFish19, Apr 25, 2024)
4d6ff87  fix mlp hidden_size (DrownFish19, Apr 25, 2024)
f350a2f  update qkv bias convert (DrownFish19, Apr 25, 2024)
c53690d  update modeling init_weight (DrownFish19, Apr 25, 2024)
9d12995  update _get_name_mappings (DrownFish19, Apr 25, 2024)
dba0f74  update _get_name_mappings and _init_weight (DrownFish19, Apr 25, 2024)
e487606  add tokenizer (DrownFish19, Apr 26, 2024)
cd9c753  update modeling (DrownFish19, Apr 26, 2024)
10407c4  update modeling (DrownFish19, Apr 26, 2024)
beb0f4c  update tokenizer (DrownFish19, Apr 26, 2024)
beefee9  update modeling and tokenizer (DrownFish19, Apr 28, 2024)
82ba345  fix index_add_ error (DrownFish19, Apr 28, 2024)
d522ee4  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, Apr 28, 2024)
4a1b2e3  fix (DrownFish19, Apr 28, 2024)
526a9db  Merge branch 'dev_add_qwen1.5-moe' of github.com:DrownFish19/PaddleNL… (DrownFish19, Apr 28, 2024)
0c9d5ec  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, May 6, 2024)
2bb3aba  update comments (DrownFish19, May 6, 2024)
f203983  update lora weights (DrownFish19, May 10, 2024)
58af3ec  add todo (DrownFish19, May 10, 2024)
c766eb5  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, May 29, 2024)
5ddc326  update Copyright (DrownFish19, May 29, 2024)
de1db67  update Moe to MoE (DrownFish19, May 29, 2024)
10a194c  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, May 30, 2024)
87f0276  update comment (DrownFish19, May 30, 2024)
8d9970b  update Copyright (DrownFish19, May 31, 2024)
89994a6  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, Jun 3, 2024)
d57a5b1  update readme and json (DrownFish19, Jun 3, 2024)
bfb65a1  update __init__.py (DrownFish19, Jun 3, 2024)
4b96dd0  add qwen-1.5 (DrownFish19, Jun 4, 2024)
b274f12  update QWen to Qwen (DrownFish19, Jun 5, 2024)
1054f06  update Qwen2MoE to Qwen2Moe (DrownFish19, Jun 5, 2024)
056b04c  update readme (DrownFish19, Jun 5, 2024)
ab08c17  update qwen2moe sft and lora json (DrownFish19, Jun 5, 2024)
ad02fdc  update qwen2moe base name (DrownFish19, Jun 5, 2024)
23e39fc  update qwen2 (DrownFish19, Jun 7, 2024)
36b3897  update (DrownFish19, Jun 7, 2024)
6455445  Merge branch 'PaddlePaddle:develop' into dev_add_qwen1.5-moe (DrownFish19, Jun 11, 2024)
b140df6  update readme (DrownFish19, Jun 11, 2024)
c08c9a6  Merge branch 'dev_add_qwen1.5-moe' of github.com:DrownFish19/PaddleNL… (DrownFish19, Jun 11, 2024)
e6de5f3  update readme (DrownFish19, Jun 11, 2024)
48ae2ab  update readme (DrownFish19, Jun 11, 2024)
Files changed
5 changes: 3 additions & 2 deletions llm/data.py
@@ -44,11 +44,12 @@ def get_convert_example(model):
 
     if base_model_prefix == "chatglm":
         return convert_example_chatglm
-    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
+    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma", "qwen2moe"]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma",
+            "qwen2moe",
         )


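Worth noting how the new error path behaves at runtime: because "qwen2moe" is passed as a second argument, the ValueError carries a two-element args tuple rather than a single concatenated message. A minimal standalone sketch (not PaddleNLP code) of the difference:

```python
# ValueError with two arguments keeps them as separate items in e.args
# instead of producing one combined message string.
try:
    raise ValueError("Unknown base_model_prefix: ... gemma", "qwen2moe")
except ValueError as e:
    print(e.args)  # ('Unknown base_model_prefix: ... gemma', 'qwen2moe')

# Folding "qwen2moe" into the message itself keeps a single string, e.g. via implicit
# concatenation of adjacent literals inside the raise:
#     f"Unknown base_model_prefix: {model.base_model_prefix}. "
#     "Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2moe"
```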
32 changes: 32 additions & 0 deletions llm/qwen2moe/lora_argument.json
@@ -0,0 +1,32 @@
{
"model_name_or_path": "qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_lora_ckpts",
Review comment (Collaborator): Confirm whether this is OK, and update the README documentation accordingly.
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 32768,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
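As a quick sanity check on the parallelism settings above, here is a short sketch (the file path is the one added in this PR; the single-node, 8-GPU layout is an assumption) that loads the config and reports the effective global batch size:

```python
import json

# Load the LoRA config added in this PR (path relative to the repository root).
with open("llm/qwen2moe/lora_argument.json") as f:
    cfg = json.load(f)

# With tensor_parallel_degree=8 and pipeline_parallel_degree=1 on an assumed 8-GPU node,
# the data-parallel degree is 1, so the global batch size reduces to
# per_device_train_batch_size * gradient_accumulation_steps.
num_gpus = 8  # assumption: one 8-GPU node
dp_degree = max(num_gpus // (cfg["tensor_parallel_degree"] * cfg["pipeline_parallel_degree"]), 1)
global_batch = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"] * dp_degree
print(f"effective global batch size: {global_batch}")  # 4 * 4 * 1 = 16
```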
30 changes: 30 additions & 0 deletions llm/qwen2moe/sft_argument.json
@@ -0,0 +1,30 @@
{
"model_name_or_path": "qwen/Qwen1.5-MoE-A2.7B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen2moe_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 32768,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"sharding": "stage2",
"pipeline_parallel_degree": 1
}
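For readers comparing the two new configs, a small sketch (paths as added in this PR) that prints every key where the SFT and LoRA arguments differ, such as the lower learning rate and the "sharding": "stage2" setting used for SFT:

```python
import json

with open("llm/qwen2moe/lora_argument.json") as f:
    lora = json.load(f)
with open("llm/qwen2moe/sft_argument.json") as f:
    sft = json.load(f)

# Report keys whose values differ, or that exist in only one of the two files.
for key in sorted(set(lora) | set(sft)):
    if lora.get(key) != sft.get(key):
        print(f"{key}: lora={lora.get(key)!r} sft={sft.get(key)!r}")
```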
3 changes: 3 additions & 0 deletions paddlenlp/transformers/__init__.py
@@ -294,6 +294,9 @@
 from .deberta_v2.modeling import *
 from .deberta_v2.tokenizer import *
 from .deberta_v2.configuration import *
+from .qwen2moe.modeling import *
+from .qwen2moe.configuration import *
+from .qwen2moe.tokenizer import *
Review comment (Collaborator), suggested change:
-from .qwen2moe.modeling import *
-from .qwen2moe.configuration import *
-from .qwen2moe.tokenizer import *
+from .qwen2moe import *

# For faster tokenizer
from ..utils.import_utils import is_fast_tokenizer_available
10 changes: 8 additions & 2 deletions paddlenlp/transformers/auto/modeling.py
@@ -118,6 +118,7 @@
         ("Bloom", "bloom"),
         ("QWen", "qwen"),
         ("Mixtral", "mixtral"),
+        ("QWen2Moe", "qwen2moe"),
         ("Gemma", "gemma"),
     ]
 )
@@ -215,15 +216,20 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file
         else:
             init_class = config.pop("init_class", None)
             init_class = init_class[:-5] if init_class is not None and init_class.endswith("Model") else init_class
+
+        # Sort MAPPING_NAMES with a longest-first rule so that model class names sharing a
+        # prefix are inferred correctly, e.g. for QWen2MoeModel the longest matching prefix
+        # is QWen2Moe rather than QWen.
         model_name = None
+        SORTED_MAPPING_NAMES = dict(sorted(MAPPING_NAMES.items(), key=lambda x: len(x[0]), reverse=True))
         if init_class:
-            for model_flag, name in MAPPING_NAMES.items():
+            for model_flag, name in SORTED_MAPPING_NAMES.items():
                 if model_flag in init_class:
                     model_name = model_flag + "Model"
                     break
         else:
             # From pretrained_model_name_or_path
-            for model_flag, name in MAPPING_NAMES.items():
+            for model_flag, name in SORTED_MAPPING_NAMES.items():
                 if name in pretrained_model_name_or_path.lower():
                     model_name = model_flag + "Model"
                     break
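To make the longest-first rule above concrete, here is a small self-contained sketch (a toy two-entry mapping, not the full MAPPING_NAMES table) showing why the sort is needed when two architectures share a prefix:

```python
from collections import OrderedDict

# Toy subset of MAPPING_NAMES; "QWen" is a prefix of "QWen2Moe".
MAPPING_NAMES = OrderedDict([("QWen", "qwen"), ("QWen2Moe", "qwen2moe")])


def resolve(init_class):
    # Longest names first, mirroring SORTED_MAPPING_NAMES in the diff above.
    sorted_names = sorted(MAPPING_NAMES.items(), key=lambda x: len(x[0]), reverse=True)
    for model_flag, name in sorted_names:
        if model_flag in init_class:
            return name
    return None


print(resolve("QWen2MoeModel"))  # qwen2moe (unsorted insertion order would match "QWen" first)
print(resolve("QWenModel"))      # qwen
```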
1 change: 1 addition & 0 deletions paddlenlp/transformers/auto/tokenizer.py
@@ -97,6 +97,7 @@
         ("BloomTokenizer", "bloom"),
         ("SpeechT5Tokenizer", "speecht5"),
         ("QWenTokenizer", "qwen"),
+        ("QWen2MoeTokenizer", "qwen2moe"),
         ("GemmaTokenizer", "gemma"),
     ]
 )
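With this mapping entry in place, AutoTokenizer can resolve Qwen2MoE checkpoints to the new tokenizer class. A hedged usage sketch (the checkpoint id is the one used in the JSON configs above; downloading it is assumed to work in your environment):

```python
from paddlenlp.transformers import AutoTokenizer

# Resolves via the ("QWen2MoeTokenizer", "qwen2moe") entry added in this diff.
tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen1.5-MoE-A2.7B")
print(type(tokenizer).__name__)  # expected: QWen2MoeTokenizer (Qwen2MoeTokenizer after the later rename commits)
```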
17 changes: 17 additions & 0 deletions paddlenlp/transformers/qwen2moe/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import QWen2MoeConfig
Review comment (Contributor): Wouldn't QWen2MoEConfig be better? Change every "Moe" to "MoE".

Review reply (Collaborator, author): Updated.

Review reply (Collaborator, author): Everything has been renamed to Qwen2 and Qwen2Moe, aligning with HF.
from .modeling import QWen2MoeForCausalLM
from .tokenizer import QWen2MoeTokenizer
Review comment (Collaborator), suggested change:
-from .configuration import QWen2MoeConfig
-from .modeling import QWen2MoeForCausalLM
-from .tokenizer import QWen2MoeTokenizer
+from .configuration import *
+from .modeling import *
+from .tokenizer import *
203 changes: 203 additions & 0 deletions paddlenlp/transformers/qwen2moe/configuration.py
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Qwen2MoE model configuration"""

from paddlenlp.transformers.configuration_utils import PretrainedConfig

__all__ = [
"QWen2MoeConfig",
]


class QWen2MoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`QWen2MoeModel`]. It is used to instantiate a
Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen1.5-MoE-A2.7B [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B).

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.


Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`QWen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.

```python
>>> from paddlenlp.transformers import QWen2MoeModel, QWen2MoeConfig

>>> # Initializing a Qwen2MoE style configuration
>>> configuration = QWen2MoeConfig()

>>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration
>>> model = QWen2MoeModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "qwen2moe"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
hidden_act="silu",
max_position_embeddings=8192,
seq_length=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
use_recompute=False,
recompute_granularity="full",
no_recompute_layers=None,
use_flash_attention=False,
attention_dropout=0.0,
use_fused_rope=False,
rope_theta=1000000.0,
tensor_parallel_output=True,
sequence_parallel=False,
fuse_sequence_parallel_allreduce=False,
pad_token_id=0,
bos_token_id=151643,
eos_token_id=151643,
tie_word_embeddings=False,
use_sliding_window=False,
sliding_window=32768,
max_window_layers=28,
decoder_sparse_step=1,
moe_intermediate_size=1408,
shared_expert_intermediate_size=5632,
num_experts_per_tok=4,
num_experts=60,
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.seq_length = seq_length
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act

self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps

self.use_cache = use_cache
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.no_recompute_layers = no_recompute_layers
self.use_flash_attention = use_flash_attention
self.tensor_parallel_output = tensor_parallel_output
self.sequence_parallel = sequence_parallel
self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce

self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.use_fused_rope = use_fused_rope
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
tensor_parallel_output=tensor_parallel_output,
**kwargs,
)
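Finally, a short usage sketch for the MoE-specific arguments (values are the defaults above; the import path assumes the wildcard export added to paddlenlp/transformers/__init__.py in this PR, and the class is renamed to Qwen2MoeConfig by later commits):

```python
from paddlenlp.transformers import QWen2MoeConfig

# With decoder_sparse_step=1, every one of the 24 decoder layers uses an MoE FFN block:
# 60 routed experts of width moe_intermediate_size, 4 of them selected per token,
# plus a shared expert of width shared_expert_intermediate_size.
config = QWen2MoeConfig(
    num_hidden_layers=24,
    decoder_sparse_step=1,
    num_experts=60,
    num_experts_per_tok=4,
    moe_intermediate_size=1408,
    shared_expert_intermediate_size=5632,
)
print(config.num_experts, config.num_experts_per_tok)  # 60 4
```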