Removing eos_token when doing inference. #351

Merged · 6 commits · Jan 30, 2024
Changes from 1 commit
fix template encode
Jintao-Huang committed Jan 30, 2024
commit e2a542290423419e708fa6cf61d1aded133c2e72
4 changes: 2 additions & 2 deletions README.md
@@ -164,7 +164,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )
 
-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +178,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions README_CN.md
@@ -164,7 +164,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )
 
-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +178,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion docs/source/LLM/LLM微调文档.md
@@ -64,7 +64,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
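Across the English README, the Chinese README, and the fine-tuning doc, the same two example changes are applied: the quick-start switches to `ModelType.qwen_1_8b`, and `InferArguments` now takes `val_dataset_sample` in place of the older `show_dataset_sample`. A minimal sketch of the updated inference snippet, assuming a fine-tuned checkpoint already exists (the `ckpt_dir` value below is a placeholder):

```python
from swift.llm import InferArguments, infer_main

# val_dataset_sample limits how many validation samples are run at inference;
# it replaces the show_dataset_sample argument removed in this commit.
infer_args = InferArguments(
    ckpt_dir='output/qwen-1_8b/vx-xxx/checkpoint-xxx',  # placeholder checkpoint path
    load_dataset_config=True,
    val_dataset_sample=10)
result = infer_main(infer_args)
```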
37 changes: 18 additions & 19 deletions swift/llm/utils/template.py
@@ -86,16 +86,8 @@ def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
             if isinstance(stop_word, str):
                 if stop_word in text:
                     return True
-            elif isinstance(stop_word, list) and len(stop_word) > 0:
-                res = []
-                for sw in stop_word:
-                    if isinstance(sw, str):
-                        token = getattr(tokenizer, sw)
-                        assert token is not None
-                    else:
-                        token = sw
-                    res.append(token)
-                if input_ids[0].tolist()[-len(res):] == res:
+            else:
+                if input_ids[0].tolist()[-len(stop_word):] == stop_word:
                     return True
         return False
 
@@ -148,6 +140,22 @@ def _init_template(self,
         self.max_length = max_length
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
+        # e.g. [['eos_token_id']] -> [[2]]
+        for key in ['prefix', 'prompt', 'chat_sep', 'suffix']:
+            value = getattr(self, key)
+            if value is None:
+                continue
+            res_value = []
+            for v in value:
+                if isinstance(v, list):
+                    res_v = []
+                    for sub_v in v:
+                        if isinstance(sub_v, str):
+                            sub_v = getattr(tokenizer, sub_v)
+                        res_v.append(sub_v)
+                    v = res_v
+                res_value.append(v)
+            setattr(self, key, res_value)
 
     def encode(
             self, example: Dict[str,
@@ -253,15 +261,6 @@ def _encode_context_list(
                     return_attention_mask=False,
                     add_special_tokens=False,
                     **curr_tokenizer_kwargs)['input_ids']
-            else:
-                token_list = []
-                for c in context:
-                    if isinstance(c, str):
-                        token = getattr(tokenizer, c)
-                        assert token is not None
-                    else:
-                        token = c
-                    token_list.append(token)
             input_ids += token_list
             if i in compute_loss_idx:
                 labels += token_list
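The template changes move special-token resolution to initialization time: in `_init_template`, any string entry such as `'eos_token_id'` inside `prefix`, `prompt`, `chat_sep`, or `suffix` is replaced by the tokenizer's integer id, so downstream code (the stop-word check and `_encode_context_list`) can treat list entries as plain token ids. A simplified, self-contained sketch of that idea; the helper names below are illustrative, not part of the swift API:

```python
from typing import List, Union

Fragment = Union[str, List[Union[str, int]]]

def resolve_fragments(fragments: List[Fragment], tokenizer) -> List[Fragment]:
    """Replace token attribute names with ids, e.g. [['eos_token_id']] -> [[2]]."""
    resolved: List[Fragment] = []
    for fragment in fragments:
        if isinstance(fragment, list):
            fragment = [
                getattr(tokenizer, item) if isinstance(item, str) else item
                for item in fragment
            ]
        resolved.append(fragment)
    return resolved

def tail_matches(generated_ids: List[int], stop_word: List[int]) -> bool:
    """With ids resolved up front, the stop check is a plain suffix comparison."""
    return generated_ids[-len(stop_word):] == stop_word
```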
4 changes: 4 additions & 0 deletions swift/llm/utils/utils.py
@@ -441,6 +441,8 @@ def inference_stream(model: PreTrainedModel,
         stream_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         stream_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        stream_config.bos_token_id = tokenizer.bos_token_id
     if stream_config.max_new_tokens is not None:
         stream_config.max_length = 20  # fix max_length, max_new_tokens warning
     stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
@@ -568,6 +570,8 @@ def inference(model: PreTrainedModel,
         generation_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         generation_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        generation_config.bos_token_id = tokenizer.bos_token_id
     if generation_config.max_new_tokens is not None:
         generation_config.max_length = 20  # fix max_length, max_new_tokens warning
     if template.suffix[-1] not in stop_words:
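Both inference helpers now copy `bos_token_id` from the tokenizer into the generation config, alongside the existing `eos_token_id` and `pad_token_id` handling. A sketch of the pattern against `transformers.GenerationConfig`; the helper name is illustrative, not part of the swift API:

```python
from transformers import GenerationConfig, PreTrainedTokenizerBase

def sync_special_token_ids(generation_config: GenerationConfig,
                           tokenizer: PreTrainedTokenizerBase) -> None:
    """Mirror the tokenizer's special token ids onto the generation config
    whenever the tokenizer defines them."""
    if tokenizer.eos_token_id is not None:
        generation_config.eos_token_id = tokenizer.eos_token_id
    if tokenizer.pad_token_id is not None:
        generation_config.pad_token_id = tokenizer.pad_token_id
    if tokenizer.bos_token_id is not None:
        generation_config.bos_token_id = tokenizer.bos_token_id
```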