
Commit

support evaluation while training
hiyouga committed May 19, 2023
1 parent: 2dbba2f · commit: 9c57a2a
Showing 12 changed files with 71 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -13,6 +13,8 @@ Fine-tuning 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) model with

## Changelog

[23/05/19] Now we support evaluating the model on a development set during training. Try the `--dev_ratio` argument to specify the size of the development set.

[23/04/29] Now we support training ChatGLM with **Reinforcement Learning from Human Feedback (RLHF)**! We provide several examples of running RLHF training; please refer to the `examples` folder for details. (experimental feature)

[23/04/20] Our repo achieved 100 stars within 12 days! Congratulations!
2 changes: 2 additions & 0 deletions README_zh.md
@@ -13,6 +13,8 @@

## Changelog

[23/05/19] We now support evaluating the model on a development set during training. Try the `--dev_ratio` argument to specify the size of the development set.

[23/04/29] We now support **RLHF (Reinforcement Learning from Human Feedback)** training! We provide several examples of running RLHF; please refer to the `examples` folder for details. (experimental feature)

[23/04/25] We added an example of distributed training with a custom dataset; please refer to [ads_generation.md](examples/ads_generation.md).
1 change: 1 addition & 0 deletions examples/quantized_finetune_with_local_model.sh
@@ -16,4 +16,5 @@ CUDA_VISIBLE_DEVICES=0 python ../src/train_sft.py \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--quantization_bit 8 \
--plot_loss \
--fp16
1 change: 1 addition & 0 deletions examples/train_ppo.sh
@@ -16,4 +16,5 @@ CUDA_VISIBLE_DEVICES=0 python ../src/train_ppo.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
1 change: 1 addition & 0 deletions examples/train_rm.sh
@@ -14,4 +14,5 @@ CUDA_VISIBLE_DEVICES=0 python ../src/train_rm.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
1 change: 1 addition & 0 deletions examples/train_sft.sh
@@ -14,4 +14,5 @@ CUDA_VISIBLE_DEVICES=0 python ../src/train_sft.py \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
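
The `--plot_loss` flag added to the example scripts above is consumed by the `plot_loss` helper that is called in `train_rm.py` further down. As a rough illustration only, here is a minimal standalone sketch of plotting the training loss from a Trainer's `trainer_state.json`; it assumes matplotlib is available and is not this repository's actual implementation:

```python
# Hypothetical sketch, not this repository's plot_loss: read the Hugging Face
# Trainer's trainer_state.json from the output directory and plot the training loss.
import json
import os

import matplotlib.pyplot as plt  # assumption: matplotlib is installed


def plot_training_loss(output_dir: str) -> None:
    # trainer_state.json keeps a "log_history" list with one entry per logging step
    with open(os.path.join(output_dir, "trainer_state.json"), "r", encoding="utf-8") as f:
        log_history = json.load(f)["log_history"]
    points = [(entry["step"], entry["loss"]) for entry in log_history if "loss" in entry]
    steps, losses = zip(*points)
    plt.plot(steps, losses)
    plt.xlabel("step")
    plt.ylabel("training loss")
    plt.savefig(os.path.join(output_dir, "training_loss.png"))


# Example: plot_training_loss("path_to_sft_checkpoint")
```
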
23 changes: 23 additions & 0 deletions examples/train_sft_with_dev_set.sh
@@ -0,0 +1,23 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/train_sft.py \
--do_train \
--dataset alpaca_gpt4_zh \
--dataset_dir ../data \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--dev_ratio 0.01 \
--evaluation_strategy steps \
--eval_steps 100 \
--load_best_model_at_end \
--plot_loss \
--fp16
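
For reference, the evaluation-related flags in the new `train_sft_with_dev_set.sh` map onto standard 🤗 Transformers training arguments roughly as in the sketch below; this is not code from the repository, and repo-specific flags such as `--dev_ratio` are parsed by the project's own dataclasses (e.g. `DataTrainingArguments` in `src/utils/config.py`, shown further down) and are omitted here.

```python
# Sketch only: the evaluation-related shell flags above expressed with the stock
# transformers Seq2SeqTrainingArguments API (argument names as of transformers 4.x).
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="path_to_sft_checkpoint",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=1000,              # a round multiple of eval_steps, required by load_best_model_at_end
    learning_rate=5e-5,
    num_train_epochs=3.0,
    evaluation_strategy="steps",  # evaluate on the held-out dev split every eval_steps
    eval_steps=100,
    load_best_model_at_end=True,
    fp16=True,
)
```

Note that `load_best_model_at_end` requires the save and evaluation schedules to line up (hence `save_steps` is a multiple of `eval_steps`), and by default the best checkpoint is selected by evaluation loss.
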
21 changes: 18 additions & 3 deletions src/train_rm.py
@@ -26,15 +26,24 @@ def main():

    training_args.remove_unused_columns = False # Important for pairwise dataset

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PairwiseTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator
        data_collator=data_collator,
        **trainer_kwargs
    )

    # Training
@@ -47,6 +56,12 @@ def main():
    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
        plot_loss(training_args)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
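
The split logic added above relies on `datasets.Dataset.train_test_split`, which returns a `DatasetDict` with `train` and `test` splits. Below is a minimal, self-contained sketch of the same pattern, assuming the 🤗 `datasets` library and using a toy in-memory dataset in place of the project's real tokenized data.

```python
# Minimal sketch of the --dev_ratio split pattern added to train_rm.py / train_sft.py,
# using a toy in-memory dataset in place of the project's real data.
from datasets import Dataset

dev_ratio = 0.01  # same value as --dev_ratio in the example script

dataset = Dataset.from_dict({"input_ids": [[i] for i in range(1000)]})

if dev_ratio > 1e-6:
    split = dataset.train_test_split(test_size=dev_ratio)
    trainer_kwargs = {"train_dataset": split["train"], "eval_dataset": split["test"]}
else:
    trainer_kwargs = {"train_dataset": dataset}

print({name: len(ds) for name, ds in trainer_kwargs.items()})
# -> {'train_dataset': 990, 'eval_dataset': 10}
```

The resulting `trainer_kwargs` dict is then unpacked into the trainer constructor via `**trainer_kwargs`, so `eval_dataset` is only passed when a development split actually exists.
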
15 changes: 12 additions & 3 deletions src/train_sft.py
@@ -31,16 +31,25 @@ def main():
    training_args.generation_num_beams = data_args.num_beams if \
        data_args.num_beams is not None else training_args.generation_num_beams

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

    # Keyword arguments for `model.generate`
10 changes: 5 additions & 5 deletions src/utils/common.py
@@ -266,8 +266,8 @@ def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTraini
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1:
        raise ValueError("We must perform a single operation among do_train, do_eval and do_predict.")
    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set to True while training.")

    if model_args.quantization_bit is not None and training_args.do_train == False:
        logger.warning("We do not recommend evaluating the model in 4/8-bit mode.")
@@ -469,10 +469,10 @@ def print_ppo_dataset_example(example):
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

    if stage == "sft":
        if training_args.do_train:
            preprocess_function = preprocess_supervised_dataset
        else:
        if (not training_args.do_train) and training_args.predict_with_generate:
            preprocess_function = preprocess_evaluation_dataset
        else:
            preprocess_function = preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
4 changes: 4 additions & 0 deletions src/utils/config.py
@@ -123,6 +123,10 @@ class DataTrainingArguments:
        default=None,
        metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )

    def __post_init__(self): # support mixing multiple datasets
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
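
The help text for `dev_ratio` says the value should lie between 0.0 and 1.0, but the diff shows no explicit check. A hypothetical guard (not part of this commit) could live in `__post_init__` along these lines, shown here on a stripped-down stand-in for the real dataclass:

```python
# Hypothetical sketch (not in this commit): range validation for dev_ratio,
# shown on a stripped-down stand-in for DataTrainingArguments.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataTrainingArguments:
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set."}
    )

    def __post_init__(self):
        if not (0.0 <= self.dev_ratio < 1.0):
            raise ValueError("`dev_ratio` should be in the range [0.0, 1.0).")


args = DataTrainingArguments(dev_ratio=0.01)  # OK
# DataTrainingArguments(dev_ratio=1.5) would raise ValueError
```
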
2 changes: 1 addition & 1 deletion src/utils/ppo.py
@@ -187,7 +187,7 @@ def generate(
        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
        unwrapped_model = self.accelerator.unwrap_model(self.model)

        response = unwrapped_model.generate(**inputs, **generation_kwargs)

