train_sft.py (forked from hiyouga/ChatGLM-Efficient-Tuning)

# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods for ChatGLM.
# This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
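# Note: all of the helpers imported below are project-local utilities (the repository's
# `utils` package), not third-party APIs.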
from utils import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    DataCollatorForChatGLM,
    Seq2SeqTrainerForChatGLM,
    ComputeMetrics,
    LogCallback,
    get_logits_processor,
    plot_loss
)

def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DataCollatorForChatGLM(tokenizer, model, data_args.ignore_pad_token_for_loss)

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
        training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.num_beams if \
        data_args.num_beams is not None else training_args.generation_num_beams

    # Split the dataset
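    # A held-out eval split is only carved out of the training data when dev_ratio is
    # (effectively) non-zero; otherwise all examples are used for training.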
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else:  # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

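    # compute_metrics above only takes effect when predict_with_generate is enabled,
    # since ComputeMetrics scores decoded generations (hence the tokenizer argument).
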
    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }
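    # Note: these nucleus-sampling settings are forwarded to `model.generate` through the
    # `trainer.evaluate` / `trainer.predict` calls below when generation is used.
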
    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
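    # Loss curves are only plotted on the main process, using the "loss" and "eval_loss"
    # values logged during training; the output lands under training_args.output_dir.
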
    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
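

# Example invocation (illustrative only): the flags below combine standard HuggingFace
# Seq2SeqTrainingArguments (--do_train, --output_dir, ...) with fields referenced in this
# script (--dev_ratio, --plot_loss). The authoritative flag set is whatever the dataclasses
# parsed by `prepare_args` define, so verify the names against your checkout.
#
#   python train_sft.py \
#       --do_train \
#       --output_dir sft_checkpoint \
#       --per_device_train_batch_size 4 \
#       --learning_rate 5e-5 \
#       --num_train_epochs 3.0 \
#       --dev_ratio 0.01 \
#       --plot_loss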