add p-tuning v2 method
hiyouga committed Apr 10, 2023
1 parent d4c3ce2 commit 94da192
Showing 5 changed files with 77 additions and 26 deletions.
17 changes: 11 additions & 6 deletions README.md
@@ -4,13 +4,11 @@
![GitHub Code License](https://img.shields.io/github/license/hiyouga/ChatGLM-Efficient-Tuning)
![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/ChatGLM-Efficient-Tuning)


Fine-tuning 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) model with 🤗[PEFT](https://github.com/huggingface/peft).


## Datasets

Now our script supports the following datasets:
Our script now supports the following datasets:

- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
@@ -26,6 +24,13 @@ Now our script supports the following datasets:

Please refer to `config_data.py` for details.

## Fine-Tuning Methods

Our script now supports the following fine-tuning methods:

- [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
- [LoRA](https://arxiv.org/abs/2106.09685)
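
For illustration, the method is selected at launch time through the `--finetuning_type` flag (see `arguments.py` and `finetune.sh` in this diff). A minimal sketch of both invocations, assuming the flag names introduced by this commit and leaving the remaining training arguments at their defaults:

```bash
# Sketch only: flag names come from arguments.py and finetune.sh in this commit.
# LoRA fine-tuning:
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py --do_train --dataset guanaco \
    --output_dir output --finetuning_type lora --fp16
# P-Tuning v2 fine-tuning (prefix length controlled by --pre_seq_len):
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py --do_train --dataset guanaco \
    --output_dir output --finetuning_type p_tuning --pre_seq_len 8 --fp16
```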

## Requirements

- Python 3.10 and PyTorch 2.0.0
@@ -111,7 +116,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_chatglm.py
- Incorporating [ChatGPT](https://openai.com/blog/chatgpt) & [GPT-4](https://openai.com/research/gpt-4) self-chat data into the training sets.
- [Baize](https://github.com/project-baize/baize-chatbot)
- ~~[GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)~~
- Implementing the Freeze-Tuning and P-Tuning method.
- Implementing the Freeze-Tuning and ~~P-Tuning~~ method.
- Supporting Multi-GPUs fine-tuning.


@@ -125,10 +130,10 @@ This repository is licensed under the [Apache-2.0 License](LICENSE).
If this work is helpful, please cite as:

```bibtex
@Misc{cet,
@Misc{chatglm-efficient-tuning,
title = {ChatGLM Efficient Tuning},
author = {hiyouga},
howpublished = {\url{https://github.com/huggingface/peft}},
howpublished = {\url{https://github.com/hiyouga/ChatGLM-Efficient-Tuning}},
year = {2023}
}
```
8 changes: 4 additions & 4 deletions arguments.py
@@ -125,12 +125,12 @@ class FinetuningArguments:
metadata={"help": "The name of fine-tuning technique."}
)
pre_seq_len: Optional[int] = field(
default=None,
metadata={"help": "Number of tokens to use for p-tuning."}
default=8,
metadata={"help": "Number of prefix tokens to use for P-tuning v2."}
)
prefix_projection: bool = field(
default=False,
metadata={"help": "Whether to add a project layer for the prefix in p-tuning."}
metadata={"help": "Whether to add a project layer for the prefix in P-tuning v2."}
)
lora_rank: Optional[int] = field(
default=8,
@@ -146,5 +146,5 @@ class FinetuningArguments:
)

def __post_init__(self):
if self.finetuning_type not in ["freeze", "p-tuning", "lora"]:
if self.finetuning_type not in ["freeze", "p_tuning", "lora"]:
raise NotImplementedError("Invalid fine-tuning method.")
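
A minimal sketch of how these fields might be parsed from the command line, assuming `FinetuningArguments` is handed to `transformers.HfArgumentParser` (as `prepare_args` in `utils.py` presumably does); the snippet is illustrative rather than part of the repository:

```python
# Illustrative: parse the fine-tuning options from CLI flags such as
# --finetuning_type p_tuning --pre_seq_len 8 (field names come from arguments.py).
from transformers import HfArgumentParser
from arguments import FinetuningArguments

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=["--finetuning_type", "p_tuning", "--pre_seq_len", "8"]
)
print(finetuning_args.finetuning_type, finetuning_args.pre_seq_len)  # p_tuning 8
```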
3 changes: 2 additions & 1 deletion finetune.sh
@@ -3,7 +3,7 @@
CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--do_train \
--dataset guanaco \
--output_dir output_guanaco \
--output_dir output \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 4 \
@@ -15,4 +15,5 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
--max_train_samples 10000 \
--learning_rate 5e-4 \
--num_train_epochs 1.0 \
--finetuning_type lora \
--fp16
9 changes: 8 additions & 1 deletion finetune_chatglm.py
@@ -8,12 +8,15 @@
# [4] https://github.com/yuanzhoulvpi2017/zero_nlp/blob/main/Chatglm6b_ModelParallel_ptuning/main.py


import os
import torch
import logging
from utils import (
prepare_args,
prepare_data,
prepare_model,
preprocess_data,
save_trainable_params,
DataCollatorForChatGLM,
TrainerForChatGLM
)
@@ -41,7 +44,11 @@ def main():
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer.train()
model.save_pretrained(training_args.output_dir)
if finetuning_args.finetuning_type == "p_tuning":
save_trainable_params(training_args.output_dir, model)
elif finetuning_args.finetuning_type == "lora":
model.save_pretrained(training_args.output_dir)
torch.save(training_args, os.path.join(training_args.output_dir, "training_args.bin"))


if __name__ == '__main__':
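
For context, a hedged sketch of how the artifacts written above might be loaded back for inference: the `p_tuning` branch saves only the trainable parameters to `adapter_model.bin` via `save_trainable_params`, while the `lora` branch uses PEFT's `save_pretrained`. The checkpoint path, the `trust_remote_code` usage, and the `strict=False` load are assumptions, not code from this commit:

```python
# Illustrative loading sketch (not part of this commit).
import os
import torch
from transformers import AutoConfig, AutoModel
from peft import PeftModel

checkpoint_dir = "output"  # hypothetical path, matching --output_dir in finetune.sh

# P-Tuning v2: rebuild the model with the same prefix settings, then load the
# filtered state dict written by save_trainable_params().
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
config.pre_seq_len = 8  # must match the value used during training
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(checkpoint_dir, "adapter_model.bin"))
model.load_state_dict(prefix_state_dict, strict=False)  # only the prefix params are present

# LoRA: wrap the base model with the adapter saved via model.save_pretrained().
base_model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, checkpoint_dir)
```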
66 changes: 52 additions & 14 deletions utils.py
@@ -18,7 +18,7 @@
from typing import Type, Dict, Sequence, Optional
from dataclasses import dataclass
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType
from peft import get_peft_model, get_peft_config, TaskType
from arguments import ModelArguments, DataTrainingArguments, FinetuningArguments


@@ -113,6 +113,11 @@ def prepare_model(model_args, finetuning_args):
use_fast=model_args.use_fast_tokenizer,
**config_kwargs
)

if finetuning_args.finetuning_type == 'p_tuning': # use the built-in p-tuning method in ChatGLM
config.pre_seq_len = finetuning_args.pre_seq_len
config.prefix_projection = finetuning_args.prefix_projection

model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
model.config.use_cache = False
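
As background on what `pre_seq_len` and `prefix_projection` control: P-Tuning v2 trains a small prefix encoder that produces `pre_seq_len` key/value vectors for every layer, either straight from an embedding table or through an MLP projection. A rough, framework-agnostic sketch (names and shapes are assumptions, not ChatGLM's exact implementation):

```python
import torch
import torch.nn as nn

class PrefixEncoder(nn.Module):
    """Rough sketch of a P-Tuning v2 prefix encoder (not ChatGLM's exact code)."""

    def __init__(self, pre_seq_len: int, num_layers: int, hidden_size: int,
                 prefix_projection: bool = False):
        super().__init__()
        # Each prefix token contributes one key and one value per transformer layer.
        out_dim = num_layers * 2 * hidden_size
        self.prefix_projection = prefix_projection
        if prefix_projection:
            # Embed to hidden_size first, then project up with a small MLP.
            self.embedding = nn.Embedding(pre_seq_len, hidden_size)
            self.trans = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, out_dim),
            )
        else:
            # Learn the per-layer key/value vectors directly.
            self.embedding = nn.Embedding(pre_seq_len, out_dim)

    def forward(self, prefix_tokens: torch.LongTensor) -> torch.Tensor:
        # prefix_tokens: (batch, pre_seq_len) -> (batch, pre_seq_len, out_dim)
        hidden = self.embedding(prefix_tokens)
        return self.trans(hidden) if self.prefix_projection else hidden
```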

@@ -121,15 +126,27 @@ def prepare_model(model_args, finetuning_args):
model = model.quantize(model_args.quantization_bit)
model.lm_head = CastOutputToFloat(model.lm_head)

if finetuning_args.finetuning_type == 'p_tuning':
logger.info("Fine-tuning method: P-Tuning V2")
model.transformer.prefix_encoder.float() # we cannot use peft since the attention mask is unusual >_<
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, total_params, 100 * trainable_params / total_params
))

if finetuning_args.finetuning_type == 'lora':
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=['query_key_value'] # query_key_value or dense
)
logger.info("Fine-tuning method: LoRA")
peft_config = {
"peft_type": "LORA",
"task_type": TaskType.CAUSAL_LM,
"inference_mode": False,
"r": finetuning_args.lora_rank,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"target_modules": ['query_key_value'] # query_key_value or dense
}
peft_config = get_peft_config(peft_config)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

@@ -199,10 +216,10 @@ def preprocess_function_train(examples):
# return model_inputs

def print_dataset_example(example):
print("input_ids", example["input_ids"])
print("inputs", tokenizer.decode(example["input_ids"]))
print("label_ids", example["labels"])
print("labels", tokenizer.decode(example["labels"]))
print("input_ids:\n", example["input_ids"])
print("inputs:\n", tokenizer.decode(example["input_ids"]))
print("label_ids:\n", example["labels"])
print("labels:\n", tokenizer.decode(example["labels"]))

if training_args.do_train:
train_dataset = raw_datasets["train"]
Expand Down Expand Up @@ -238,6 +255,23 @@ def print_dataset_example(example):
# return eval_dataset


def filter_model_params(model): # filter out the frozen parameters
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k]
return filtered_state_dict


def save_trainable_params(save_directory, model):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
filtered_state_dict = filter_model_params(model)
torch.save(filtered_state_dict, os.path.join(save_directory, "adapter_model.bin"))


"""
Note: The ChatGLM tokenizer assigns False to tokens that should be attended to in the attention mask, whereas in general settings it would be True.
Refer to: https://huggingface.co/THUDM/chatglm-6b/blob/6650ae3a53c28fc176d06762ca80b05d5ab3792b/tokenization_chatglm.py#L401
@@ -268,7 +302,11 @@ def _save(self, output_dir: Optional[str] = None, _internal_call: bool = False):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
self.model.save_pretrained(output_dir) # only save peft weights
if hasattr(self.model, "pre_seq_len"): # p-tuning v2
filtered_state_dict = filter_model_params(self.model)
torch.save(filtered_state_dict, os.path.join(output_dir, "adapter_model.bin"))
elif hasattr(self.model, "peft_config"): # LoRA
self.model.save_pretrained(output_dir) # only save peft weights with the built-in method
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))


