support multiple datasets, add Chinese readme
hiyouga committed Apr 11, 2023
1 parent 40758d1 commit f2820ef
Showing 7 changed files with 295 additions and 124 deletions.
23 changes: 13 additions & 10 deletions README.md
@@ -6,6 +6,8 @@

Fine-tuning the 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) model with 🤗[PEFT](https://github.com/huggingface/peft).

\[ English | [中文](README_zh.md) \]

## Datasets

Our script now supports the following datasets:
@@ -24,6 +26,8 @@ Our script now supports the following datasets:

Please refer to `config_data.py` for details.

[23/04/11] We now support training with combined datasets! Pass a comma-separated list, e.g. `--dataset dataset1,dataset2`, to train on multiple datasets at once.
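
For instance, a minimal sketch of mixing two of the supported datasets (both names are defined in `config_data.py`; the remaining arguments follow the fine-tuning example below):

```bash
# Hypothetical invocation: combine the Chinese Alpaca data with the GPT-4 generated Chinese data
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
    --do_train \
    --dataset alpaca_zh,alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir output_mixed \
    --fp16
```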

## Fine-Tuning Methods

Our script now supports the following fine-tuning methods:
@@ -41,7 +45,7 @@ And **powerful GPUs**!

## Getting Started

### Preparation
### Preparation (optional)

```bash
git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
@@ -55,8 +59,9 @@ pip install -r requirements.txt
```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_train \
--dataset guanaco \
--output_dir output_guanaco \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
@@ -66,7 +71,7 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--save_steps 1000 \
--warmup_steps 100 \
--max_train_samples 10000 \
--learning_rate 5e-4 \
--learning_rate 5e-5 \
--num_train_epochs 1.0 \
--fp16
```
@@ -85,9 +90,10 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py


## Compared with Existing Implementations

- [THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)
- Official implementation of fine-tuning ChatGLM with [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) on the [ADGEN](https://aclanthology.org/D19-1321.pdf) dataset.
- Our fine-tuning script largely depends on it. We further implement the [LoRA](https://arxiv.org/abs/2106.09685) tuning method. Additionally, we **dynamically** pad the inputs to the longest sequence in the batch instead of the maximum length.
- Our fine-tuning script largely depends on it. We further implement the [LoRA](https://arxiv.org/abs/2106.09685) tuning method. Additionally, we **dynamically** pad the inputs to the longest sequence in the batch instead of the maximum length, which accelerates fine-tuning.
- [mymusise/ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning)
- An unofficial implementation of fine-tuning ChatGLM with [LoRA](https://arxiv.org/abs/2106.09685) on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
- We borrowed some ideas from it. Our fine-tuning script **integrates** the data pre-processing part into the training procedure, so we need not generate a pre-processed dataset before training.
@@ -104,11 +110,10 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
- An unofficial implementation of fine-tuning ChatGLM that explores ChatGLM's ability on instruction-following datasets.
- Our fine-tuning script integrates the data pre-processing part into the training procedure.


## TODO

- Incorporating [Chinese datasets](https://github.com/brightmart/nlp_chinese_corpus) into the training sets.
- [BELLE](https://github.com/LianjiaTech/BELLE)
- ~~[BELLE](https://github.com/LianjiaTech/BELLE)~~
- [pCLUE](https://github.com/CLUEbenchmark/pCLUE)
- [CLUECorpus](https://github.com/CLUEbenchmark/CLUECorpus2020)
- ~~[GuanacoDataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)~~
@@ -118,13 +123,12 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
- ~~[GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)~~
- Implementing the Freeze-Tuning and ~~P-Tuning~~ methods.
- Supporting multi-GPU fine-tuning.
- Adding a script for evaluation.

## License

This repository is licensed under the [Apache-2.0 License](LICENSE).


## Citation

If this work is helpful, please cite as:
@@ -138,7 +142,6 @@ If this work is helpful, please cite as:
}
```


## Acknowledgement

This repo benefits from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning) and [yuanzhoulvpi2017/zero_nlp](https://github.com/yuanzhoulvpi2017/zero_nlp). Thanks for their wonderful work.
148 changes: 148 additions & 0 deletions README_zh.md
@@ -0,0 +1,148 @@
# ChatGLM Efficient Tuning

![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/ChatGLM-Efficient-Tuning?style=social)
![GitHub Code License](https://img.shields.io/github/license/hiyouga/ChatGLM-Efficient-Tuning)
![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/ChatGLM-Efficient-Tuning)

Fine-tuning the 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) model with 🤗[PEFT](https://github.com/huggingface/peft).

\[ [English](README.md) | 中文 \]

## Datasets

Our script now supports the following datasets:

- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)

Please refer to `config_data.py` for details.

[23/04/11] We now support training on combined datasets! Pass a comma-separated list such as `dataset1,dataset2` to train on multiple datasets at once.
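
As a quick sketch (both dataset names are defined in `config_data.py`; the remaining arguments follow the fine-tuning example below):

```bash
# Hypothetical invocation: combine the Chinese Alpaca data with the GPT-4 generated Chinese data
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
    --do_train \
    --dataset alpaca_zh,alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir output_mixed \
    --fp16
```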

## Fine-Tuning Methods

Our script now supports the following fine-tuning methods:

- [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
- [LoRA](https://arxiv.org/abs/2106.09685)

## Requirements

- Python 3.10, PyTorch 2.0.0
- 🤗Transformers, Datasets, PEFT
- protobuf, cpm_kernels, sentencepiece

And **powerful GPUs**!

## Getting Started

### Preparation (optional)

```bash
git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
conda create -n cet python=3.10
conda activate cet
pip install -r requirements.txt
```

### Fine-Tuning

```bash
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_train \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--warmup_steps 100 \
--max_train_samples 10000 \
--learning_rate 5e-5 \
--num_train_epochs 1.0 \
--fp16
```

### Inference

```bash
CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
```

### Hardware Requirements

| Batch size | LoRA `r` | Mode | GPU Memory |
| ---------- | -------- | ---- | ---------- |
| 8 | 8 | FP16 | 24GB |


## Compared with Existing Implementations

- [THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)
- The official implementation of fine-tuning ChatGLM with [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) on the [ADGEN](https://aclanthology.org/D19-1321.pdf) dataset.
- Our fine-tuning script is largely based on it. We further implement the [LoRA](https://arxiv.org/abs/2106.09685) tuning method. Additionally, we **dynamically** pad the inputs to the longest sequence in each batch instead of the model's maximum length, which accelerates fine-tuning.
- [mymusise/ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning)
- An unofficial implementation of fine-tuning ChatGLM with [LoRA](https://arxiv.org/abs/2106.09685) on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
- We borrowed some ideas from it. Our training script **integrates** the data pre-processing step into the training procedure, so there is no need to generate a pre-processed dataset before training.
- [ssbuild/chatglm_finetuning](https://github.com/ssbuild/chatglm_finetuning)
- An unofficial implementation of fine-tuning ChatGLM with several fine-tuning methods on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
- Our training script is implemented **entirely** with the [Huggingface transformers](https://github.com/huggingface/transformers) framework and does not depend on the extra [deep_training](https://github.com/ssbuild/deep_training) framework.
- [lich99/ChatGLM-finetune-LoRA](https://github.com/lich99/ChatGLM-finetune-LoRA)
- An unofficial implementation of fine-tuning ChatGLM with [LoRA](https://arxiv.org/abs/2106.09685) on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
- We use the [Huggingface PEFT](https://github.com/huggingface/peft) framework to bring in state-of-the-art fine-tuning methods.
- [liucongg/ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning)
- An unofficial implementation of fine-tuning ChatGLM with Freeze-Tuning, LoRA and P-Tuning on an automotive-industry dataset.
- We aim to incorporate more instruction-following datasets for fine-tuning the ChatGLM model.
- [yanqiangmiffy/InstructGLM](https://github.com/yanqiangmiffy/InstructGLM)
- An unofficial implementation of fine-tuning ChatGLM that explores ChatGLM's ability on instruction-following datasets.
- We integrate the data pre-processing step into the training procedure.

## TODO

- Incorporating more [Chinese datasets](https://github.com/brightmart/nlp_chinese_corpus) into the training sets.
- ~~[BELLE](https://github.com/LianjiaTech/BELLE)~~
- [pCLUE](https://github.com/CLUEbenchmark/pCLUE)
- [CLUECorpus](https://github.com/CLUEbenchmark/CLUECorpus2020)
- ~~[GuanacoDataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)~~
- ~~[FireflyDataset](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)~~
- Incorporating datasets generated by [ChatGPT](https://openai.com/blog/chatgpt) and [GPT-4](https://openai.com/research/gpt-4).
- [Baize](https://github.com/project-baize/baize-chatbot)
- ~~[GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)~~
- Implementing the Freeze-Tuning and ~~P-Tuning~~ methods.
- Supporting multi-GPU fine-tuning.
- Adding a script for evaluation.

## License

This repository is licensed under the [Apache-2.0 License](LICENSE).

## Citation

If this work is helpful, please cite it as:

```bibtex
@Misc{chatglm-efficient-tuning,
title = {ChatGLM Efficient Tuning},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/ChatGLM-Efficient-Tuning}},
year = {2023}
}
```


## Acknowledgement

This repo benefits from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning) and [yuanzhoulvpi2017/zero_nlp](https://github.com/yuanzhoulvpi2017/zero_nlp). Thanks for their wonderful work.
67 changes: 40 additions & 27 deletions arguments.py
@@ -48,19 +48,19 @@ class ModelArguments:
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training.
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
dataset: Optional[str] = field(
default="alpaca_zh",
metadata={"help": "The name of provided dataset to use."}
metadata={"help": "The name of provided dataset to use. Use comma to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached training and sets."}
metadata={"help": "Overwrite the cached training and evaluation sets."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
@@ -80,7 +80,11 @@ class DataTrainingArguments:
)
max_train_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of training examples."}
metadata={"help": "For debugging purposes, truncate the number of training examples for each dataset."}
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of evaluation examples for each dataset."}
)
ignore_pad_token_for_loss: bool = field(
default=True,
@@ -91,31 +95,40 @@
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)

def __post_init__(self):
if self.dataset not in DATASETS:
raise ValueError("Undefined dataset in config_data.py.")

if "hf_hub_url" in DATASETS[self.dataset]:
self.load_from = "hf_hub"
self.dataset_name = DATASETS[self.dataset]["hf_hub_url"]
elif "script_url" in DATASETS[self.dataset]:
self.load_from = "script"
self.dataset_name = DATASETS[self.dataset]["script_url"]
def __post_init__(self): # support mixing multiple datasets
if "," in self.dataset:
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
else:
self.load_from = "file"
self.train_file = DATASETS[self.dataset]["filename"]
self.train_hash = DATASETS[self.dataset]["sha1"]
dataset_names = [self.dataset.strip()]

if "columns" in DATASETS[self.dataset]:
self.prompt_column = DATASETS[self.dataset]["columns"]["prompt"]
self.query_column = DATASETS[self.dataset]["columns"]["query"]
self.response_column = DATASETS[self.dataset]["columns"]["response"]
self.history_column = DATASETS[self.dataset]["columns"]["history"]
else:
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
self.dataset_list = []
for name in dataset_names:
if name not in DATASETS:
raise ValueError("Undefined dataset {} in config_data.py.".format(name))

dataset_info = {}
if "hf_hub_url" in DATASETS[name]:
dataset_info["load_from"] = "hf_hub"
dataset_info["dataset_name"] = DATASETS[name]["hf_hub_url"]
elif "script_url" in DATASETS[name]:
dataset_info["load_from"] = "script"
dataset_info["dataset_name"] = DATASETS[name]["script_url"]
else:
dataset_info["load_from"] = "file"
dataset_info["file_name"] = DATASETS[name]["file_name"]
dataset_info["file_sha1"] = DATASETS[name]["file_sha1"]

if "columns" in DATASETS[name]:
dataset_info["prompt_column"] = DATASETS[name]["columns"]["prompt"]
dataset_info["query_column"] = DATASETS[name]["columns"]["query"]
dataset_info["response_column"] = DATASETS[name]["columns"]["response"]
dataset_info["history_column"] = DATASETS[name]["columns"]["history"]
else:
dataset_info["prompt_column"] = "instruction"
dataset_info["query_column"] = "input"
dataset_info["response_column"] = "output"
dataset_info["history_column"] = None
self.dataset_list.append(dataset_info)


@dataclass
8 changes: 4 additions & 4 deletions config_data.py
@@ -5,8 +5,8 @@
"dataset_name": {
"hf_hub_url": the name of the dataset repository on the HF hub. (if specified, ignore below 3 arguments)
"script_url": the name of the script in the local `dataset_dir` directory. (if specified, ignore below 2 arguments)
"filename": the name of the dataset file in the local `dataset_dir` directory. (required if hf_hub_url not specified)
"sha1": the SHA-1 hash value of the dataset file. (required if hf_hub_url not specified)
"file_name": the name of the dataset file in the local `dataset_dir` directory. (required if hf_hub_url not specified)
"file_sha1": the SHA-1 hash value of the dataset file. (required if hf_hub_url not specified)
"columns": { (optional, if not provided, use the default values)
"prompt": the name of the column in the datasets containing the prompts. (default: instruction)
"query": the name of the column in the datasets containing the queries. (default: input)
@@ -20,8 +20,8 @@
DATASETS = {
"alpaca_en": {"hf_hub_url": "tatsu-lab/alpaca"},
"alpaca_zh": {
"filename": "alpaca_data_zh_51k.json",
"sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
"file_name": "alpaca_data_zh_51k.json",
"file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
},
"alpaca_gpt4_en": {"hf_hub_url": "c-s-ale/alpaca-gpt4-data"},
"alpaca_gpt4_zh": {"hf_hub_url": "c-s-ale/alpaca-gpt4-data-zh"},
6 changes: 3 additions & 3 deletions finetune.sh
@@ -2,7 +2,8 @@

CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_train \
--dataset guanaco \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
@@ -13,7 +14,6 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--save_steps 1000 \
--warmup_steps 100 \
--max_train_samples 10000 \
--learning_rate 5e-4 \
--learning_rate 5e-5 \
--num_train_epochs 1.0 \
--finetuning_type lora \
--fp16
4 changes: 2 additions & 2 deletions finetune_chatglm.py
@@ -28,9 +28,9 @@
def main():

model_args, data_args, training_args, finetuning_args = prepare_args()
raw_datasets = prepare_data(model_args, data_args)
dataset = prepare_data(model_args, data_args, training_args)
tokenizer, model = prepare_model(model_args, finetuning_args)
dataset = preprocess_data(raw_datasets, tokenizer, data_args, training_args)
dataset = preprocess_data(dataset, tokenizer, data_args, training_args)
data_collator = DataCollatorForChatGLM(tokenizer=tokenizer, data_args=data_args)
trainer = TrainerForChatGLM(
model=model,
