Skip to content

Commit

Permalink
Support mistral nemo series models (modelscope#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet committed Jul 20, 2024
1 parent 245234f commit bfe509c
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

## 🎉 News
- 2024.07.20: Support mistral-nemo series models. Use `--model_type mistral-nemo-base-2407` and `--model_type mistral-nemo-instruct-2407` to begin.
- 2024.07.19: Support [Q-Galore](https://arxiv.org/abs/2407.08296), this algorithm can reduce the training memory cost by 60% (qwen-7b-chat, full, 80G -> 35G), use `swift sft --model_type xxx --use_galore true --galore_quantization true` to begin!
- 2024.07.17: Support newly released InternVL2 models: `model_type` are internvl2-1b, internvl2-40b, internvl2-llama3-76b. For best practices, refer to [here](docs/source_en/Multi-Modal/internvl-best-practice.md).
- 2024.07.17: Support the training and inference of [NuminaMath-7B-TIR](https://huggingface.co/AI-MO/NuminaMath-7B-TIR). Use with model_type `numina-math-7b`.
Expand Down
1 change: 1 addition & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:


## 🎉 新闻
- 2024.07.20: 支持mistral-nemo系列模型. 使用`--model_type mistral-nemo-base-2407`以及`--model_type mistral-nemo-instruct-2407`开始训练和推理.
- 🔥2024.07.19: 支持[Q-Galore](https://arxiv.org/abs/2407.08296)算法, 该算法可以减少显存使用约60% (qwen-7b-chat, full, 80G -> 35G), 使用命令行:`swift sft --model_type xxx --use_galore true --galore_quantization true`来开始训练!
- 2024.07.17: 支持InternVL2系列新模型: `model_type`分别为internvl2-1b, internvl2-40b, internvl2-llama3-76b. 最佳实践可以查看[这里](docs/source/Multi-Modal/internvl最佳实践.md).
- 2024.07.17: 支持[NuminaMath-7B-TIR](https://www.modelscope.cn/models/AI-ModelScope/NuminaMath-7B-TIR)的训练和推理. model_type可以使用`numina-math-7b`.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/LLM/支持的模型和数据集.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@
|mistral-7b-instruct|[AI-ModelScope/Mistral-7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)|
|mistral-7b-instruct-v2|[AI-ModelScope/Mistral-7B-Instruct-v0.2](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.2/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|
|mistral-7b-instruct-v3|[LLM-Research/Mistral-7B-Instruct-v0.3](https://modelscope.cn/models/LLM-Research/Mistral-7B-Instruct-v0.3/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)|
|mistral-nemo-instruct-2407|[AI-ModelScope/Mistral-Nemo-Instruct-2407](https://modelscope.cn/models/AI-ModelScope/Mistral-Nemo-Instruct-2407/summary)|q_proj, k_proj, v_proj|mistral-nemo|&#x2714;|&#x2714;|transformers>=4.43.0.dev0|-|[mistralai/Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)|
|mistral-nemo-base-2407|[AI-ModelScope/Mistral-Nemo-Base-2407](https://modelscope.cn/models/AI-ModelScope/Mistral-Nemo-Base-2407/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.43.0.dev0|-|[mistralai/Mistral-Nemo-Base-2407](https://huggingface.co/mistralai/Mistral-Nemo-Base-2407)|
|mixtral-moe-7b|[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.36|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)|
|mixtral-moe-7b-instruct|[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.36|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)|
|mixtral-moe-7b-aqlm-2bit-1x16|[AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2718;|transformers>=4.38, aqlm, torch>=2.2.0|-|[ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://huggingface.co/ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf)|
Expand Down
2 changes: 2 additions & 0 deletions docs/source_en/LLM/Supported-models-datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ The table below introcudes all models supported by SWIFT:
|mistral-7b-instruct|[AI-ModelScope/Mistral-7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)|
|mistral-7b-instruct-v2|[AI-ModelScope/Mistral-7B-Instruct-v0.2](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.2/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|
|mistral-7b-instruct-v3|[LLM-Research/Mistral-7B-Instruct-v0.3](https://modelscope.cn/models/LLM-Research/Mistral-7B-Instruct-v0.3/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)|
|mistral-nemo-instruct-2407|[AI-ModelScope/Mistral-Nemo-Instruct-2407](https://modelscope.cn/models/AI-ModelScope/Mistral-Nemo-Instruct-2407/summary)|q_proj, k_proj, v_proj|mistral-nemo|&#x2714;|&#x2714;|transformers>=4.43.0.dev0|-|[mistralai/Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)|
|mistral-nemo-base-2407|[AI-ModelScope/Mistral-Nemo-Base-2407](https://modelscope.cn/models/AI-ModelScope/Mistral-Nemo-Base-2407/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.43.0.dev0|-|[mistralai/Mistral-Nemo-Base-2407](https://huggingface.co/mistralai/Mistral-Nemo-Base-2407)|
|mixtral-moe-7b|[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.36|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)|
|mixtral-moe-7b-instruct|[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|&#x2714;|&#x2714;|transformers>=4.36|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)|
|mixtral-moe-7b-aqlm-2bit-1x16|[AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2718;|transformers>=4.38, aqlm, torch>=2.2.0|-|[ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://huggingface.co/ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf)|
Expand Down
20 changes: 20 additions & 0 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ class ModelType:
mistral_7b_instruct = 'mistral-7b-instruct'
mistral_7b_instruct_v2 = 'mistral-7b-instruct-v2'
mistral_7b_instruct_v3 = 'mistral-7b-instruct-v3'
mistral_nemo_instruct_2407 = 'mistral-nemo-instruct-2407'
mistral_nemo_base_2407 = 'mistral-nemo-base-2407'
mixtral_moe_7b = 'mixtral-moe-7b'
mixtral_moe_7b_instruct = 'mixtral-moe-7b-instruct'
mixtral_moe_7b_aqlm_2bit_1x16 = 'mixtral-moe-7b-aqlm-2bit-1x16' # aqlm
Expand Down Expand Up @@ -2505,6 +2507,24 @@ def _output_device_map_hook(module, input, output):
support_flash_attn=True,
support_vllm=True,
hf_model_id='mistral-community/Mixtral-8x22B-v0.1')
@register_model(
ModelType.mistral_nemo_instruct_2407,
'AI-ModelScope/Mistral-Nemo-Instruct-2407',
LoRATM.llama,
TemplateType.mistral_nemo,
requires=['transformers>=4.43.0.dev0'],
support_flash_attn=True,
support_vllm=True,
hf_model_id='mistralai/Mistral-Nemo-Instruct-2407')
@register_model(
ModelType.mistral_nemo_base_2407,
'AI-ModelScope/Mistral-Nemo-Base-2407',
LoRATM.llama,
TemplateType.default_generation,
requires=['transformers>=4.43.0.dev0'],
support_flash_attn=True,
support_vllm=True,
hf_model_id='mistralai/Mistral-Nemo-Base-2407')
@register_model(
ModelType.dbrx_base,
'AI-ModelScope/dbrx-base',
Expand Down
14 changes: 11 additions & 3 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class TemplateType:
llama_llava_next = 'llama-llava-next'
llava_next_video = 'llava-next-video'
llava_next_video_yi = 'llava-next-video-yi'
mistral_nemo = 'mistral-nemo'
openbuddy = 'openbuddy'
openbuddy2 = 'openbuddy2'
internlm = 'internlm'
Expand Down Expand Up @@ -187,7 +188,7 @@ def __init__(self,
prefix = self._replace_system(prefix)
self.prefix = prefix
self.system_prefix = system_prefix
if self.system_prefix is None:
if self.system_prefix is None and not any(['{{SYSTEM}}' in context for context in prompt]):
assert default_system is None, 'The template does not support `system`.'
self.prompt = prompt
self.chat_sep = chat_sep
Expand Down Expand Up @@ -561,7 +562,10 @@ def _encode(self,
if isinstance(bos_token_id, int) and bos_token_id in self.tokenizer.encode(''):
res_context_list.append([bos_token_id])
loss_scale_list.append(0.)
prompt = self.prompt.copy()
if system is None:
prompt = [context for context in prompt if '{{SYSTEM}}' not in context]
if system is None or any(['{{SYSTEM}}' in context for context in prompt]):
prefix = self.prefix
else:
prefix = self.system_prefix
Expand All @@ -571,8 +575,9 @@ def _encode(self,
history_roles.append([query_role, 'assistant'])

for i, ((q, r), (qr, rr)) in enumerate(zip(history, history_roles)):
context_list = self.tool_prompt.copy() if qr == 'tool' else self.prompt.copy()
context_list = self.tool_prompt.copy() if qr == 'tool' else prompt.copy()
if i < len(history) - 1:
context_list = [context for context in context_list if '{{SYSTEM}}' not in context]
context_list.append('{{RESPONSE}}')
if history[i + 1][0]:
context_list += self.chat_sep
Expand All @@ -582,7 +587,7 @@ def _encode(self,
context_list += self.suffix
if q or r:
self._concat_context_list(
context_list, res_context_list, loss_scale_list, query=q, response=r, round0=i)
context_list, res_context_list, loss_scale_list, query=q, response=r, system=system, round0=i)
res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, **kwargs)
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(res_context_list, loss_scale_list)

Expand Down Expand Up @@ -1151,6 +1156,9 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
Template(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'], LLAMA_DEFAULT_SYSTEM,
['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))

register_template(TemplateType.mistral_nemo,
Template(['<s>[INST] '], ['{{SYSTEM}}\n\n', '{{QUERY}} [/INST]'], ['[INST] '], ['</s>']))

register_template(
TemplateType.llama3,
Template(['<|begin_of_text|>'], [
Expand Down

0 comments on commit bfe509c

Please sign in to comment.