Commit

fix bleu score, save custom code
hiyouga committed Jul 4, 2023
1 parent 3a53dd4 commit 99487f2
Showing 5 changed files with 18 additions and 10 deletions.
1 change: 0 additions & 1 deletion src/export_model.py
@@ -14,7 +14,6 @@ def main():
     model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)
-    print("Remember to copy the *.py files from the original directory.")


 if __name__ == "__main__":
5 changes: 5 additions & 0 deletions src/utils/common.py
@@ -225,6 +225,11 @@ def load_pretrained(
     # Load and prepare pretrained models (without valuehead).
     model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)

+    # Register auto class to save the custom code files.
+    config.__class__.register_for_auto_class()
+    tokenizer.__class__.register_for_auto_class()
+    model.__class__.register_for_auto_class()
+
     if model_args.use_v2:
         assert tokenizer.eos_token_id is not None, "Please update the *.json and *.py files of ChatGLM2-6B from HuggingFace."
         model.lm_head = model.transformer.output_layer
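
Registering the remote-code classes for auto class is what lets save_pretrained copy the custom *.py files (configuration/tokenization/modeling code) into the output directory, which is why the "remember to copy the *.py files" reminder could be dropped from src/export_model.py above. A minimal sketch of the effect, assuming a remote-code checkpoint loaded with trust_remote_code; the model id and output path are only illustrative:

    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # Illustrative remote-code checkpoint; any model that ships its own *.py files works.
    name = "THUDM/chatglm2-6b"
    config = AutoConfig.from_pretrained(name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModel.from_pretrained(name, config=config, trust_remote_code=True)

    # Register the dynamically loaded classes so save_pretrained also writes their
    # source files and the auto_map entry needed to reload them via the Auto* classes.
    config.__class__.register_for_auto_class()
    tokenizer.__class__.register_for_auto_class()
    model.__class__.register_for_auto_class()

    model.save_pretrained("exported", max_shard_size="1GB")
    tokenizer.save_pretrained("exported")
    # "exported/" now contains the custom code files alongside the weights.
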
2 changes: 1 addition & 1 deletion src/utils/config.py
@@ -205,7 +205,7 @@ class FinetuningArguments:
         metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
     )
     pre_seq_len: Optional[int] = field(
-        default=16,
+        default=64,
         metadata={"help": "Number of prefix tokens to use for P-tuning V2."}
     )
     prefix_projection: Optional[bool] = field(
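
A quick way to see the new default in effect (a minimal sketch; the import path assumes src/ is on PYTHONPATH and that all FinetuningArguments fields have defaults, as they appear to):

    from transformers import HfArgumentParser
    from utils.config import FinetuningArguments  # import path assumed relative to src/

    # Parsing an empty argument list yields the dataclass defaults.
    finetuning_args, = HfArgumentParser(FinetuningArguments).parse_args_into_dataclasses(args=[])
    print(finetuning_args.pre_seq_len)  # 64 (P-tuning v2 prefix length; previously 16)
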
2 changes: 2 additions & 0 deletions src/utils/peft_trainer.py
@@ -105,11 +105,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str,
         if self.finetuning_args.finetuning_type == "lora":
             backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
         else: # freeze/full tuning
+            backbone_model.config.use_cache = True
             backbone_model.save_pretrained(
                 output_dir,
                 state_dict=get_state_dict(backbone_model),
                 safe_serialization=self.args.save_safetensors
             )
+            backbone_model.config.use_cache = False
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)

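
The use_cache toggle matters because save_pretrained serializes whatever the in-memory config currently holds: use_cache is typically disabled during training (e.g. alongside gradient checkpointing), but the exported checkpoint should advertise caching for fast generation. A small illustration of that round trip, using a stand-in config (gpt2 is only an example):

    import tempfile
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("gpt2")   # stand-in; any causal LM config works
    config.use_cache = False                      # the value typically held during training
    with tempfile.TemporaryDirectory() as tmp:
        config.save_pretrained(tmp)
        print(AutoConfig.from_pretrained(tmp).use_cache)  # False: the flag is persisted as-is
    # Hence the trainer flips use_cache to True just for the save, then back to False.
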
18 changes: 10 additions & 8 deletions src/utils/seq2seq.py
@@ -41,9 +41,12 @@ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

-        for pred, label in zip(preds, labels):
-            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
-            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))

             if len(" ".join(hypothesis).split()) == 0:
                 result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
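
Downstream of this loop, the segmented hypothesis/reference pairs feed Chinese ROUGE and smoothed sentence-level BLEU. A self-contained sketch of that scoring step, assuming rouge-chinese and nltk (consistent with the rouge-1/2/l result shape above); the strings are made-up examples:

    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge_chinese import Rouge

    pred, label = "今天天气很好", "今天天气不错"  # stand-ins for one decoded pair
    hypothesis = list(jieba.cut(pred))
    reference = list(jieba.cut(label))

    # ROUGE over space-joined segments, BLEU over the token lists with smoothing.
    scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))
    bleu = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3)
    print(scores[0]["rouge-l"]["f"], bleu)
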
@@ -103,12 +106,11 @@ def save_predictions(
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for pred, label in zip(preds, labels):
-                pred = self.tokenizer.decode(pred, skip_special_tokens=True)
-                label = self.tokenizer.decode(label, skip_special_tokens=True)
-
+            for pred, label in zip(decoded_preds, decoded_labels):
                 res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
-
             writer.write("\n".join(res))
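
Each written line is a standalone JSON object, so the prediction file can be read back line by line. A small sketch, assuming the file name used by the trainer (adjust to the actual output path):

    import json

    # Assumed output location; one {"label": ..., "predict": ...} object per line.
    with open("generated_predictions.txt", encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    print(rows[0]["predict"], "|", rows[0]["label"])
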
