add checkpoint and Freeze method
hiyouga committed Apr 12, 2023
1 parent 754d900 commit 5a90d3e
Showing 9 changed files with 308 additions and 190 deletions.
46 changes: 30 additions & 16 deletions README.md
@@ -8,6 +8,12 @@ Fine-tuning 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) model with

\[ English | [中文](README_zh.md) \]

## Change Log

[23/04/12] Now we support training from checkpoints! Use `--checkpoint_dir` to specify the checkpoint to resume fine-tuning from.

[23/04/11] Now we support training with combined datasets! Try the `dataset1,dataset2` argument for training with multiple datasets.

## Datasets

Our script now supports the following datasets:
@@ -26,14 +32,16 @@ Our script now supports the following datasets:

Please refer to `config_data.py` for details.

[23/04/11] Now we support training with combined datasets! Try the `dataset1,dataset2` argument for training with multiple datasets.

## Fine-Tuning Methods

Our script now supports the following fine-tuning methods:

- [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
- We fine-tune the prefix encoder of the model.
- [LoRA](https://arxiv.org/abs/2106.09685)
  - We fine-tune the model with low-rank adapters.
- [Freeze](https://arxiv.org/abs/2012.14913)
  - We fine-tune the MLPs in the last n blocks (see the sketch below).
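
The Freeze method only updates a thin slice of the network. Below is a minimal sketch of the idea, assuming a ChatGLM-style layout where transformer blocks live under `transformer.layers[i]` and each block exposes an `mlp` submodule; the attribute path, layer count, and function name are illustrative assumptions, not code from this commit.

```python
import torch.nn as nn

def freeze_tune(model: nn.Module, num_layer_trainable: int = 2, num_layers: int = 28) -> None:
    """Keep only the MLPs of the last `num_layer_trainable` blocks trainable."""
    trainable_prefixes = [
        f"transformer.layers.{i}.mlp"
        for i in range(num_layers - num_layer_trainable, num_layers)
    ]
    for name, param in model.named_parameters():
        # Freeze everything, then re-enable parameters inside the selected MLPs.
        param.requires_grad = any(name.startswith(p) for p in trainable_prefixes)
```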

## Requirement

@@ -58,19 +66,16 @@ pip install -r requirements.txt
### Fine-tuning

```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
python finetune_chatglm.py \
--do_train \
--dataset alpaca_zh \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--warmup_steps 100 \
--max_train_samples 10000 \
--learning_rate 5e-5 \
--num_train_epochs 1.0 \
@@ -80,29 +85,37 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
### Evaluation (BLEU and ROUGE_CHINESE)

```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
python finetune_chatglm.py \
--do_eval \
--dataset alpaca_zh \
--checkpoint_dir output \
--output_dir eval \
--overwrite_cache \
--overwrite_output_dir \
--per_device_eval_batch_size 1 \
--max_eval_samples 20 \
--max_eval_samples 50 \
--predict_with_generate
```

### Inference

```bash
CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
python infer_chatglm.py --checkpoint_dir output
```

### Hardware Requirements

| Batch size | LoRA `r` | Mode | GRAM |
| ---------- | -------- | ---- | ---- |
| 8 | 8 | FP16 | 24GB |
| Fine-tuning method | Batch size | Mode | GRAM | Speed  |
| ------------------ | ---------- | ---- | ---- | ------ |
| LoRA (r=8)         | 8          | FP16 | 20GB | 7 ex/s |
| LoRA (r=8)         | 16         | FP16 | 26GB | 8 ex/s |
| P-Tuning (p=8)     | 8          | FP16 | 24GB | 8 ex/s |
| Freeze (l=2)       | 2          | FP16 | 32GB | 4 ex/s |

<sub>
r: LoRA rank,
p: number of prefix tokens,
l: number of trainable layers,
ex/s: examples per second
</sub>

## Compared with Existing Implementations

@@ -136,9 +149,10 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
- Incorporating [ChatGPT](https://openai.com/blog/chatgpt) & [GPT-4](https://openai.com/research/gpt-4) self-chat data into the training sets.
- [Baize](https://github.com/project-baize/baize-chatbot)
- ~~[GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)~~
- Implementing the Freeze-Tuning and ~~P-Tuning~~ methods.
- ~~Implementing the Freeze-Tuning and P-Tuning methods.~~
- Supporting multi-GPU fine-tuning.
- ~~Add a script for evaluation.~~ (but it appears to be very slow)
- ~~Load from checkpoint.~~

## License

44 changes: 28 additions & 16 deletions README_zh.md
@@ -8,6 +8,12 @@

\[ [English](README.md) | 中文 \]

## 更新日志

[23/04/12] 现在我们加入了断点训练支持!请尝试给定 `--checkpoint_dir` 参数加载指定的模型断点。

[23/04/11] 现在我们实现了数据集组合训练!请尝试使用 `dataset1,dataset2` 参数进行组合训练。

## 数据集

目前我们实现了针对以下数据集的支持:
@@ -26,14 +32,16 @@

使用方法请参考 `config_data.py`

[23/04/11] 现在我们实现了数据集组合训练!请尝试使用 `dataset1,dataset2` 参数进行组合训练。

## 微调方法

目前我们实现了针对以下高效微调方法的支持:

- [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
- 仅微调前缀编码器。
- [LoRA](https://arxiv.org/abs/2106.09685)
- 仅微调低秩适应器。
- [Freeze](https://arxiv.org/abs/2012.14913)
- 仅微调后几层的全连接层。

## 软件依赖

@@ -58,19 +66,16 @@ pip install -r requirements.txt
### 微调训练

```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
python finetune_chatglm.py \
--do_train \
--dataset alpaca_zh \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--warmup_steps 100 \
--max_train_samples 10000 \
--learning_rate 5e-5 \
--num_train_epochs 1.0 \
@@ -80,28 +85,34 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
### 指标评估(BLEU分数和汉语ROUGE分数)

```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
python finetune_chatglm.py \
--do_eval \
--dataset alpaca_zh \
--checkpoint_dir output \
--output_dir eval \
--overwrite_cache \
--overwrite_output_dir \
--per_device_eval_batch_size 1 \
--max_eval_samples 20 \
--max_eval_samples 50 \
--predict_with_generate
```

### 效果测试

```bash
CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
python infer_chatglm.py --checkpoint_dir output
```

### 硬件需求

| 批处理大小 | LoRA `r` | 模式 | GPU显存 |
| --------- | -------- | ---- | ------ |
| 8 | 8 | FP16 | 24GB |
| 微调方法 | 批处理大小 | 模式 | GPU显存 | 速度 |
| ---------------- | ---------- | ---- | ------ | ----- |
| LoRA (r=8) | 8 | FP16 | 20GB | 7ex/s |
| LoRA (r=8) | 16 | FP16 | 26GB | 8ex/s |
| P-Tuning (p=8) | 8 | FP16 | 24GB | 8ex/s |
| Freeze (l=2) | 2 | FP16 | 32GB | 4ex/s |

<sub>
r:LoRA 维数大小,p:前缀词表大小,l:微调层数,ex/s:每秒训练的样本数
</sub>


## 和现有类似项目的比较
@@ -136,9 +147,10 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
- 加入基于 [ChatGPT](https://openai.com/blog/chatgpt)[GPT-4](https://openai.com/research/gpt-4) 产生的数据集。
- [Baize](https://github.com/project-baize/baize-chatbot)
- ~~[GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)~~
- 实现参数冻结和 ~~P-Tuning~~ 微调方法。
- ~~实现参数冻结和 P-Tuning 微调方法。~~
- 支持多GPU训练。
- ~~加入模型评估脚本。~~(但它可能很慢!)
- ~~断点加载。~~

## 协议

63 changes: 41 additions & 22 deletions arguments.py
@@ -1,6 +1,21 @@
from typing import Optional
from dataclasses import dataclass, field
from config_data import CHATGLM_LASTEST_HASH, DATASETS
from config_data import CHATGLM_REPO_NAME, CHATGLM_LASTEST_HASH, DATASETS


@dataclass
class DatasetInfo:

load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
file_sha1: Optional[str] = None

def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
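
For illustration, a hypothetical instance of the dataclass above. `__post_init__` fills in the default alpaca-style column names, which later code can overwrite per dataset; the file name below is made up.

```python
info = DatasetInfo("file", file_name="my_data.json", file_sha1=None)  # hypothetical values
assert info.prompt_column == "instruction" and info.query_column == "input"
info.history_column = "history"  # e.g. for a dataset with a multi-turn history column
```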


@dataclass
@@ -9,7 +24,7 @@ class ModelArguments:
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: Optional[str] = field(
default="THUDM/chatglm-6b",
default=CHATGLM_REPO_NAME,
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
)
config_name: Optional[str] = field(
@@ -41,7 +56,12 @@ class ModelArguments:
metadata={"help": "Whether to resize the position embeddings if `max_source_length` exceeds."}
)
quantization_bit: Optional[int] = field(
default=None
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
)


@@ -100,7 +120,7 @@ class DataTrainingArguments:
)

def __post_init__(self): # support mixing multiple datasets
if "," in self.dataset:
if self.dataset.find(",") != -1:
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
else:
dataset_names = [self.dataset.strip()]
@@ -110,28 +130,23 @@ def __post_init__(self): # support mixing multiple datasets
if name not in DATASETS:
raise ValueError("Undefined dataset {} in config_data.py.".format(name))

dataset_info = {}
if "hf_hub_url" in DATASETS[name]:
dataset_info["load_from"] = "hf_hub"
dataset_info["dataset_name"] = DATASETS[name]["hf_hub_url"]
dataset_info = DatasetInfo("hf_hub", dataset_name=DATASETS[name]["hf_hub_url"])
elif "script_url" in DATASETS[name]:
dataset_info["load_from"] = "script"
dataset_info["dataset_name"] = DATASETS[name]["script_url"]
dataset_info = DatasetInfo("script", dataset_name=DATASETS[name]["script_url"])
else:
dataset_info["load_from"] = "file"
dataset_info["file_name"] = DATASETS[name]["file_name"]
dataset_info["file_sha1"] = DATASETS[name]["file_sha1"]
dataset_info = DatasetInfo(
"file",
file_name=DATASETS[name]["file_name"],
file_sha1=DATASETS[name]["file_sha1"]
)

if "columns" in DATASETS[name]:
dataset_info["prompt_column"] = DATASETS[name]["columns"]["prompt"]
dataset_info["query_column"] = DATASETS[name]["columns"]["query"]
dataset_info["response_column"] = DATASETS[name]["columns"]["response"]
dataset_info["history_column"] = DATASETS[name]["columns"]["history"]
else:
dataset_info["prompt_column"] = "instruction"
dataset_info["query_column"] = "input"
dataset_info["response_column"] = "output"
dataset_info["history_column"] = None
dataset_info.prompt_column = DATASETS[name]["columns"]["prompt"]
dataset_info.query_column = DATASETS[name]["columns"]["query"]
dataset_info.response_column = DATASETS[name]["columns"]["response"]
dataset_info.history_column = DATASETS[name]["columns"]["history"]

self.dataset_list.append(dataset_info)
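
For reference, a small sketch of what the parsing above does with a combined `--dataset` value; it only mirrors the splitting logic and is not part of the commit (the dataset names are taken from `config_data.py`):

```python
raw = "alpaca_en,alpaca_zh"  # e.g. passed as --dataset alpaca_en,alpaca_zh
if raw.find(",") != -1:
    dataset_names = [ds.strip() for ds in raw.split(",")]
else:
    dataset_names = [raw.strip()]
assert dataset_names == ["alpaca_en", "alpaca_zh"]  # one DatasetInfo is built per name
```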


@@ -144,6 +159,10 @@ class FinetuningArguments:
default="lora",
metadata={"help": "The name of fine-tuning technique."}
)
num_layer_trainable: Optional[int] = field(
default=2,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
)
pre_seq_len: Optional[int] = field(
default=8,
metadata={"help": "Number of prefix tokens to use for P-tuning v2."}
@@ -166,5 +185,5 @@ class FinetuningArguments:
)

def __post_init__(self):
if self.finetuning_type not in ["freeze", "p_tuning", "lora"]:
if self.finetuning_type not in ["none", "freeze", "p_tuning", "lora"]:
raise NotImplementedError("Invalid fine-tuning method.")
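
These dataclasses are typically consumed through `transformers.HfArgumentParser`; whether `finetune_chatglm.py` does exactly this is an assumption, since that file's diff is not shown here:

```python
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from arguments import ModelArguments, DataTrainingArguments, FinetuningArguments

parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)
)
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

if finetuning_args.finetuning_type == "freeze":
    # num_layer_trainable controls how many of the last blocks keep trainable MLPs.
    print(f"Freeze-tuning the last {finetuning_args.num_layer_trainable} blocks")
```
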
3 changes: 2 additions & 1 deletion config_data.py
@@ -16,7 +16,8 @@
}
"""

CHATGLM_LASTEST_HASH = 'aa51e62ddc9c9f334858b0af44cf59b05c70148a'
CHATGLM_REPO_NAME = "THUDM/chatglm-6b"
CHATGLM_LASTEST_HASH = "aa51e62ddc9c9f334858b0af44cf59b05c70148a"
DATASETS = {
"alpaca_en": {"hf_hub_url": "tatsu-lab/alpaca"},
"alpaca_zh": {
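
A custom dataset could be registered by appending an entry that uses only the keys `DataTrainingArguments.__post_init__` understands (`hf_hub_url`, `script_url`, or `file_name`/`file_sha1`, plus an optional `columns` mapping). The entry below is purely illustrative; the dataset name, file name, and the use of `None` to skip hash checking are assumptions:

```python
DATASETS["my_dataset"] = {
    "file_name": "my_dataset.json",  # hypothetical local data file
    "file_sha1": None,               # assumed to skip SHA-1 verification
    "columns": {
        "prompt": "instruction",
        "query": "input",
        "response": "output",
        "history": None,             # no multi-turn history column
    },
}
```
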
3 changes: 1 addition & 2 deletions evaluate.sh
@@ -2,10 +2,9 @@

CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_eval \
--dataset alpaca_zh \
--dataset alpaca_gpt4_zh \
--output_dir eval \
--overwrite_cache \
--overwrite_output_dir \
--per_device_eval_batch_size 1 \
--max_eval_samples 20 \
--predict_with_generate
3 changes: 1 addition & 2 deletions finetune.sh
@@ -2,11 +2,10 @@

CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_train \
--dataset alpaca_zh \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \