Removing eos_token when doing inference. (modelscope#351)
Jintao-Huang committed Jan 30, 2024
1 parent 6a19a54 commit 1dea2fc
Showing 9 changed files with 65 additions and 30 deletions.
README.md: 9 changes (6 additions, 3 deletions)
@@ -62,7 +62,8 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用.md)


 ## 🎉 News
-- 2024.1.26: Support [yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
+- 2024.1.29: Support internlm2-math series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
+- 🔥2024.1.26: Support [yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
 - 2024.1.24: Support codefuse-codegeex2-6b-chat, codefuse-qwen-14b-chat.
 - 2024.1.23: Support orion series: orion-14b, [orion-14b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/orion_14b_chat).
 - 2024.1.20: Support [xverse-13b-256k](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/xverse_13b_256k), xverse-65b-v2, xverse-65b-chat.
@@ -164,7 +165,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )

-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +179,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
@@ -222,6 +223,8 @@ app_ui_main(infer_args)
   - [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder) series: deepseek-coder-1_3b, deepseek-coder-1_3b-instruct, deepseek-coder-6_7b, deepseek-coder-6_7b-instruct, deepseek-coder-33b, deepseek-coder-33b-instruct.
   - [codegeex2](https://github.com/THUDM/CodeGeeX2) series: codegeex2-6b.
   - [phi](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) series: phi2-3b.
+- Math:
+  - [internlm2-math](https://github.com/InternLM/InternLM-Math) series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - Supported Datasets: [[Detailed Info]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
   - NLP:
     - General: 🔥alpaca-en(gpt4), 🔥alpaca-zh(gpt4), multi-alpaca-all, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, instruct-en, gpt4all-en, sharegpt-en, sharegpt-zh, tutu-v2-sft-mixture, wikipedia-zh, open-orca, open-orca-gpt4, sharegpt-gpt4.
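Across both READMEs and the fine-tuning doc, the quickstart now passes `val_dataset_sample` instead of the removed `show_dataset_sample`. A minimal sketch of the updated inference call, assuming a fine-tuned checkpoint already exists (the `ckpt_dir` value is a placeholder, not taken from this commit):

```python
from swift.llm import InferArguments, infer_main

# val_dataset_sample caps how many validation samples are run at inference.
# The ckpt_dir below is illustrative only.
infer_args = InferArguments(
    ckpt_dir='output/qwen-1_8b/vx-xxx/checkpoint-xxx',
    load_dataset_config=True,
    val_dataset_sample=10)
result = infer_main(infer_args)
```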
README_CN.md: 7 changes (5 additions, 2 deletions)
@@ -60,6 +60,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is an extensible
 Users can check the [SWIFT official documentation](docs/source/GetStarted/快速使用.md) for details.

 ## 🎉 News
+- 2024.1.29: Support internlm2-math series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - 2024.1.26: Support [yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
 - 2024.1.24: Support codefuse-codegeex2-6b-chat, codefuse-qwen-14b-chat.
 - 2024.1.23: Support orion series: orion-14b, [orion-14b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/orion_14b_chat).
@@ -164,7 +165,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )

-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +179,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
@@ -222,6 +223,8 @@ app_ui_main(infer_args)
   - [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder) series: deepseek-coder-1_3b, deepseek-coder-1_3b-instruct, deepseek-coder-6_7b, deepseek-coder-6_7b-instruct, deepseek-coder-33b, deepseek-coder-33b-instruct.
   - [codegeex2](https://github.com/THUDM/CodeGeeX2) series: codegeex2-6b.
   - [phi](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) series: phi2-3b.
+- Math:
+  - [internlm2-math](https://github.com/InternLM/InternLM-Math) series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - Supported Datasets: [[Detailed Info]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
   - NLP:
     - General: 🔥alpaca-en(gpt4), 🔥alpaca-zh(gpt4), multi-alpaca-all, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, instruct-en, gpt4all-en, sharegpt-en, sharegpt-zh, tutu-v2-sft-mixture, wikipedia-zh, open-orca, open-orca-gpt4, sharegpt-gpt4.
docs/source/LLM/LLM微调文档.md: 2 changes (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
docs/source/LLM/支持的模型和数据集.md: 4 changes (4 additions, 0 deletions)
@@ -68,6 +68,10 @@
 |internlm2-20b|[Shanghai_AI_Laboratory/internlm2-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-20b/summary)|wqkv|default-generation-bos|✔|✘||
 |internlm2-20b-sft-chat|[Shanghai_AI_Laboratory/internlm2-chat-20b-sft](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-20b-sft/summary)|wqkv|internlm2|✔|✘||
 |internlm2-20b-chat|[Shanghai_AI_Laboratory/internlm2-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-20b/summary)|wqkv|internlm2|✔|✘||
+|internlm2-math-7b|[Shanghai_AI_Laboratory/internlm2-math-base-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-base-7b/summary)|wqkv|default-generation-bos|✔|✘||
+|internlm2-math-7b-chat|[Shanghai_AI_Laboratory/internlm2-math-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-7b/summary)|wqkv|internlm2|✔|✘||
+|internlm2-math-20b|[Shanghai_AI_Laboratory/internlm2-math-base-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-base-20b/summary)|wqkv|default-generation-bos|✔|✘||
+|internlm2-math-20b-chat|[Shanghai_AI_Laboratory/internlm2-math-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-20b/summary)|wqkv|internlm2|✔|✘||
 |deepseek-7b|[deepseek-ai/deepseek-llm-7b-base](https://modelscope.cn/models/deepseek-ai/deepseek-llm-7b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔||
 |deepseek-7b-chat|[deepseek-ai/deepseek-llm-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-llm-7b-chat/summary)|q_proj, k_proj, v_proj|deepseek|✔|✔||
 |deepseek-moe-16b|[deepseek-ai/deepseek-moe-16b-base](https://modelscope.cn/models/deepseek-ai/deepseek-moe-16b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✘||
scripts/utils/run_model_info.py: 9 changes (8 additions, 1 deletion)
@@ -56,7 +56,14 @@ def get_model_info_readme_zh(data: List[str]) -> None:
     model_list = []
     for match in match_list:
         model_list += match[2].strip('.').split(',')
-    model_list = [model.strip() for model in model_list]
+    model_list_2 = []
+    for model in model_list:
+        model = model.strip()
+        model_match = re.search(r'\[(.+)\]\(.+\)', model)
+        if model_match is not None:
+            model = model_match.group(1)
+        model_list_2.append(model)
+    model_list = model_list_2
     model_type_list = [d[0] for d in data]
     print(set(model_type_list) - set(model_list))
     print(set(model_list) - set(model_type_list))
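The rewritten loop normalizes README entries such as `[orion-14b-chat](https://...)` down to the bare model name before diffing against `ModelType`. A self-contained sketch of the same regex, run on illustrative inputs:

```python
import re

# Illustrative entries as they appear in the README model lists.
entries = ['orion-14b', ' [orion-14b-chat](https://github.com/modelscope/swift)']
names = []
for entry in entries:
    entry = entry.strip()
    # Unwrap a markdown link: '[name](url)' -> 'name'.
    match = re.search(r'\[(.+)\]\(.+\)', entry)
    if match is not None:
        entry = match.group(1)
    names.append(entry)
print(names)  # ['orion-14b', 'orion-14b-chat']
```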
swift/llm/utils/model.py: 4 changes (2 additions, 2 deletions)
@@ -98,10 +98,10 @@ class ModelType:
     internlm2_20b_sft_chat = 'internlm2-20b-sft-chat'
     internlm2_20b_chat = 'internlm2-20b-chat'
     # internlm2-math
-    internlm2_math_7b_chat = 'internlm2-math-7b-chat'
     internlm2_math_7b = 'internlm2-math-7b'
-    internlm2_math_20b_chat = 'internlm2-math-20b-chat'
+    internlm2_math_7b_chat = 'internlm2-math-7b-chat'
     internlm2_math_20b = 'internlm2-math-20b'
+    internlm2_math_20b_chat = 'internlm2-math-20b-chat'
     # deepseek
     deepseek_7b = 'deepseek-7b'
     deepseek_7b_chat = 'deepseek-7b-chat'
swift/llm/utils/template.py: 44 changes (26 additions, 18 deletions)
@@ -86,16 +86,9 @@ def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
             if isinstance(stop_word, str):
                 if stop_word in text:
                     return True
-            elif isinstance(stop_word, list) and len(stop_word) > 0:
-                res = []
-                for sw in stop_word:
-                    if isinstance(sw, str):
-                        token = getattr(tokenizer, sw)
-                        assert token is not None
-                    else:
-                        token = sw
-                    res.append(token)
-                if input_ids[0].tolist()[-len(res):] == res:
+            else:  # list
+                if len(stop_word) > 0 and input_ids[0].tolist(
+                )[-len(stop_word):] == stop_word:
                     return True
         return False

@@ -132,6 +125,24 @@ def __init__(self,
         self.use_default_system = True
         self._is_init = False

+    @staticmethod
+    def _preprocess_prompt(tokenizer: PreTrainedTokenizerBase,
+                           value: Optional[Prompt]) -> Optional[Prompt]:
+        # e.g. [['eos_token_id']] -> [[2]]
+        if value is None:
+            return None
+        res_value = []
+        for v in value:
+            if isinstance(v, list):
+                res_v = []
+                for sub_v in v:
+                    if isinstance(sub_v, str):
+                        sub_v = getattr(tokenizer, sub_v)
+                    res_v.append(sub_v)
+                v = res_v
+            res_value.append(v)
+        return res_value
+
     def _init_template(self,
                        tokenizer: PreTrainedTokenizerBase,
                        default_system: Optional[str] = None,
@@ -148,6 +159,10 @@ def _init_template(self,
         self.max_length = max_length
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
+        for key in ['prefix', 'prompt', 'chat_sep', 'suffix']:
+            value = getattr(self, key)
+            value = self._preprocess_prompt(tokenizer, value)
+            setattr(self, key, value)

     def encode(
             self, example: Dict[str,
@@ -254,14 +269,7 @@ def _encode_context_list(
                     add_special_tokens=False,
                     **curr_tokenizer_kwargs)['input_ids']
             else:
-                token_list = []
-                for c in context:
-                    if isinstance(c, str):
-                        token = getattr(tokenizer, c)
-                        assert token is not None
-                    else:
-                        token = c
-                    token_list.append(token)
+                token_list = context
             input_ids += token_list
             if i in compute_loss_idx:
                 labels += token_list
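The net effect of the template.py changes: string placeholders such as `'eos_token_id'` are resolved to concrete token ids once, in `_init_template`, so `_encode_context_list` and `StopWordsCriteria` can compare plain id lists. A minimal illustration of `_preprocess_prompt` with a hypothetical tokenizer stub (the import path follows this commit's file layout and is an assumption, not verified here):

```python
from swift.llm.utils.template import Template

class TokenizerStub:
    """Hypothetical stand-in for a PreTrainedTokenizerBase."""
    eos_token_id = 2

# Placeholder form, as a template might declare its suffix...
suffix = [['eos_token_id']]
# ...resolved to concrete ids, per the '# e.g.' comment in the diff.
assert Template._preprocess_prompt(TokenizerStub(), suffix) == [[2]]
```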
swift/llm/utils/utils.py: 4 changes (4 additions, 0 deletions)
@@ -441,6 +441,8 @@ def inference_stream(model: PreTrainedModel,
         stream_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         stream_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        stream_config.bos_token_id = tokenizer.bos_token_id
     if stream_config.max_new_tokens is not None:
         stream_config.max_length = 20  # fix max_length, max_new_tokens warning
         stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
@@ -568,6 +570,8 @@ def inference(model: PreTrainedModel,
         generation_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         generation_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        generation_config.bos_token_id = tokenizer.bos_token_id
     if generation_config.max_new_tokens is not None:
         generation_config.max_length = 20  # fix max_length, max_new_tokens warning
     if template.suffix[-1] not in stop_words:
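Both `inference` and `inference_stream` now also copy `bos_token_id` from the tokenizer onto the generation config, mirroring the existing `eos`/`pad` handling. A standalone sketch of the pattern using the transformers API (the model id is illustrative):

```python
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative model id
generation_config = GenerationConfig(max_new_tokens=32)

# Copy special-token ids only when the tokenizer defines them, so an
# unset pad/bos token does not overwrite the config with None.
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
if tokenizer.bos_token_id is not None:
    generation_config.bos_token_id = tokenizer.bos_token_id
```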
tests/llm/test_run.py: 12 changes (9 additions, 3 deletions)
@@ -251,7 +251,9 @@ def test_cogagent_instruct(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))

     def test_yi_vl_6b_chat(self):
         if not __name__ == '__main__':
@@ -272,7 +274,9 @@ def test_yi_vl_6b_chat(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))

     def test_dpo(self):
         if not __name__ == '__main__':
@@ -288,7 +292,9 @@ def test_dpo(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))


 def data_collate_fn(batch: List[Dict[str, Any]],
