Skip to content

Simple 4-bit/8-bit LoRA fine-tuning for ChatGLM2 with peft and transformers.Trainer.

License

Notifications You must be signed in to change notification settings

FreddyBanana/ChatGLM2-LoRA-Trainer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

24 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

ChatGLM2-LoRA-Trainer

简介 / Introduction

本仓库利用peft库与transformers.Trainer,实现对ChatGLM2的简单4-bit/8-bit LoRA微调。(其它LLM应该也行,只要稍作修改)

This repo uses peft and transformers.Trainer to achieve simple 4-bit/8-bit LoRA fine-tuning for ChatGLM2. (You can also use this repo for other LLM with minor modifications)

安装依赖 / Installing the dependencies

$ pip install -r requirement.txt

requirement.txt:

datasets==2.13.1
protobuf
transformers==4.30.2
cpm_kernels
torch>=2.0
mdtex2html
sentencepiece
accelerate
git+https://github.com/huggingface/peft.git
bitsandbytes
loralib
scipy

参数/config

文件config.py参数如下:

  • MICRO_BATCH_SIZE,每块GPU的batch size大小。
  • BATCH_SIZE,真正的batch size,当每个batch的处理样本数达到BATCH_SIZE时,进行梯度更新。
  • EPOCHS,总训练代数。
  • WARMUP_STEPS,预热步数。
  • LEARNING_RATE,学习率。
  • CONTEXT_LEN,context字段截断长度(对应json文件的context)。
  • TARGET_LEN,target字段截断长度(对应json文件的target)。
  • TEXT_LEN,text字段截断长度(对应txt文件的文本)。
  • LORA_R,LoRA低秩的秩数。
  • LORA_ALPHA,LoRA的alpha。
  • LORA_DROPOUT,LoRA层的Dropout率。
  • MODEL_NAME,模型名称(huggingface仓库地址)。
  • LOGGING_STEPS,日志步数,即训练的时候输出loss的间隔步数。
  • OUTPUT_DIR,输出LoRA权重的存放文件夹位置。
  • DATA_PATH,数据集文件位置。
  • DATA_TYPE,数据集文件类型,可选json或txt。
  • SAVE_STEPS,保存LoRA权重的间隔步数。
  • SAVE_TOTAL_LIMIT,保存LoRA权重文件的总数(不包括最终权重)。
  • PROMPT,推理时的prompt。
  • TEMPERATURE,推理时的温度,调整模型的创造力。
  • LORA_CHECKPOINT_DIR,待推理LoRA权重的文件夹位置。
  • BIT_4,使用4bit量化+LoRA微调。
  • BIT_8,使用8bit量化+LoRA微调。

The parameters in config.py are as follows:

  • MICRO_BATCH_SIZE,batch size on each device。
  • BATCH_SIZE,when the number of processed samples in each split batch reaches BATCH_SIZE, update the gradient.
  • EPOCHS,training epochs。
  • WARMUP_STEPS,warmup steps。
  • LEARNING_RATE,learning rate of fine-tuning。
  • CONTEXT_LEN,truncation length of context (in json)。
  • TARGET_LEN,truncation length of target (in json)。
  • TEXT_LEN,truncation length of text (in txt)。
  • LORA_R,LoRA low rank。
  • LORA_ALPHA,LoRA Alpha。
  • LORA_DROPOUT,LoRA dropout。
  • MODEL_NAME,model name (huggingface repo address)。
  • LOGGING_STEPS,the number of interval steps for outputting loss during training。
  • OUTPUT_DIR,the storage folder location for LoRA weights。
  • DATA_PATH,the location of your dataset file。
  • DATA_TYPE,the type of your dataset file, including json and txt。
  • SAVE_STEPS,the number of interval steps to save LoRA weights。
  • SAVE_TOTAL_LIMIT,the total number of LoRA weight files saved (excluding the final one)。
  • PROMPT,your prompt when inference。
  • TEMPERATURE,the temperature when inference, adjusting the creativity of LLM。
  • LORA_CHECKPOINT_DIR,folder location for LoRA weights to be inferred。
  • BIT_4,use 4-bit。
  • BIT_8,use 8-bit。

数据集文件/Dataset files

1)json

json文件格式如下:

The JSON file format is as follows:

{"context":question1, "target":answer1}{"context":question2, "target":answer2}...

2)txt

txt文件格式如下:

The txt file format is as follows:

sentence1
sentence2
sentence3
...

使用方法 / Usage

1)训练/train

$ sh train.sh

train.sh:

python main.py \
	--MICRO_BATCH_SIZE 8 \
	--BATCH_SIZE 16 \
	--EPOCHS 50 \
	--LEARNING_RATE 5e-4 \
	--CONTEXT_LEN 64 \
	--TARGET_LEN 192 \
	--LORA_R 16 \
	--LORA_DROPOUT 0.5 \
	--MODEL_NAME THUDM/chatglm2-6b \
	--OUTPUT_DIR ./output_model \
	--DATA_PATH ./new_train.json \
	--DATA_TYPE json \
	--SAVE_STEPS 1000 \
	--BIT_4

2)推理/inference

$ sh inference.sh

inference.sh:

python inference.py \
	--CONTEXT_LEN 256 \
	--MODEL_NAME THUDM/chatglm2-6b \
	--LORA_CHECKPOINT_DIR ./output_model/checkpoint-4000/ \
	--BIT_4 \
	--PROMPT "put your prompt here"

参考 / Reference

THUDM/ChatGLM2-6B: ChatGLM2-6B: An Open Bilingual Chat LLM | 开源双语对话语言模型 (github.com)

Fine_Tuning_LLama | Kaggle

mymusise/ChatGLM-Tuning: 一种平价的chatgpt实现方案, 基于ChatGLM-6B + LoRA (github.com)

更新日志/ChangeLog

  • [2023/08/02]:更新了LoraConfig的target_modules。
  • [2023/07/27]:对于QA的训练,更新loss的计算目标,只计算问题部分(json里面的target字段)的loss。
  • [2023/07/25]:添加4-bit量化LoRA训练。
  • [2023/07/24]:添加eos_token_id,解决重复输出的问题。

About

Simple 4-bit/8-bit LoRA fine-tuning for ChatGLM2 with peft and transformers.Trainer.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published