fix fatal error in finetuning
Caused by utils.py:L135: model.half() cannot be used during finetuning.
hiyouga committed Apr 10, 2023
1 parent 384968b commit 8780d8c
Showing 4 changed files with 21 additions and 13 deletions.
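
For context on the fix: the old prepare_model() in utils.py called model.half() (old line 135) before training, casting every trainable weight to fp16. That is what made finetuning fail; among other problems, PyTorch's gradient scaler, which backs the --fp16 flag this commit adds to finetune.sh, refuses to unscale gradients of fp16 parameters. The commit instead leaves the weights as loaded and only casts the lm_head output back to float32 through a new CastOutputToFloat wrapper in utils.py. A minimal sketch of that wrapper follows; the Identity module and sample tensor are stand-ins, not repository code.

    import torch

    # Same wrapper as the one added to utils.py in this commit: run the wrapped
    # module(s), then cast the result up to float32 so the loss is computed in
    # full precision even when the rest of the forward pass runs in fp16.
    class CastOutputToFloat(torch.nn.Sequential):

        def forward(self, x):
            return super().forward(x).to(torch.float32)

    # Stand-in demonstration (not repository code): the wrapped module's output
    # comes back as float32 regardless of the input dtype.
    wrapped = CastOutputToFloat(torch.nn.Identity())
    x = torch.randn(2, 4, dtype=torch.float16)
    print(wrapped(x).dtype)  # torch.float32
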
arguments.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 from dataclasses import dataclass, field


-CHATGLM_LASTEST_HASH = 'acd41f77311be8584836edc2fc7251d5b6e65840'
+CHATGLM_LASTEST_HASH = 'cde457b39fe0670b10dd293909aab17387ea2c80'
 DATASETS = {
     "alpaca_en": {
         "train": {
finetune.sh (7 changes: 5 additions & 2 deletions)
@@ -6,6 +6,9 @@ CUDA_VISIBLE_DEVICES=0 python finetune_chatglm.py \
     --output_dir output \
     --per_device_train_batch_size 4 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 8 \
-    --logging_steps 50 \
+    --gradient_accumulation_steps 4 \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --warmup_steps 100 \
+    --fp16 \
     --num_train_epochs 1.0
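
For reference, the effective batch size per optimizer step becomes per_device_train_batch_size × gradient_accumulation_steps = 4 × 4 = 16 on the single GPU selected by CUDA_VISIBLE_DEVICES=0, down from 4 × 8 = 32 before this commit, while the new --fp16 flag recovers, through the Trainer's mixed-precision path, the memory savings that model.half() previously provided.
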
finetune_chatglm.py (1 change: 1 addition & 0 deletions)
@@ -46,6 +46,7 @@ def main():
     model.save_pretrained(training_args.output_dir)

     # Testing
+    model = model.half()
     model.eval()
     response, _ = model.chat(tokenizer, query='你好', history=[])
     print(response)
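
model.half() now runs only here, after training has finished and the fine-tuned weights have been saved: the chat smoke test executes under model.eval() with no gradients, where an fp16 cast is safe and roughly halves the memory needed to hold the weights.
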
utils.py (24 changes: 14 additions & 10 deletions)
@@ -26,6 +26,12 @@
 logger = logging.getLogger(__name__)


+class CastOutputToFloat(torch.nn.Sequential):
+
+    def forward(self, x):
+        return super().forward(x).to(torch.float32)
+
+
 def prepare_args():
     # Load arguments
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, FinetuningArguments))
@@ -113,14 +119,14 @@ def prepare_model(model_args, finetuning_args):
         **config_kwargs
     )
     model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
+    model.config.use_cache = False

     if model_args.quantization_bit is not None:
         print("Quantized to {} bit".format(model_args.quantization_bit))
         model = model.quantize(model_args.quantization_bit)
+    model.lm_head = CastOutputToFloat(model.lm_head)

     if finetuning_args.finetuning_type == 'lora':
-        for param in model.parameters():
-            param.requires_grad_(False)
         peft_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             inference_mode=False,
@@ -132,8 +138,6 @@ def prepare_model(model_args, finetuning_args):
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()

-    model = model.half()
-
     return tokenizer, model


Expand Down Expand Up @@ -174,6 +178,7 @@ def format_example(examples):
yield prompt, answer

def preprocess_function_train(examples):
# build inputs with format `X [gMASK] [BOS] Y [EOP]`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
@@ -193,6 +198,7 @@ def preprocess_function_train(examples):
         return model_inputs

     def preprocess_function_eval(examples):
+        # build inputs with format `X [gMASK] [BOS]`
         model_inputs = {"input_ids": [], "labels": []}
         for prompt, answer in format_example(examples):
             source_ids = tokenizer.encode(text=prompt)
@@ -250,7 +256,7 @@ def print_dataset_example(example):
 """
 Note: The ChatGLM tokenizer assigns False on token to be attended in attention mask. In general settings, it should be True.
 Refer to: https://huggingface.co/THUDM/chatglm-6b/blob/6650ae3a53c28fc176d06762ca80b05d5ab3792b/tokenization_chatglm.py#L401
-Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py#L166
+Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/aa65c492bb788e144712daab42bc5d11c2761591/train.py#L166
 """
 @dataclass
 class DataCollatorForChatGLM(DataCollatorWithPadding):
@@ -263,17 +269,15 @@ def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]:
         label_pad_token_id = IGNORE_INDEX if self.data_args.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=label_pad_token_id)
         features = {"input_ids": input_ids, "labels": labels}
-        return super().__call__(features)
+        # return super().__call__(features) # enable generating attention mask and position ids
+        return features


 """
-Inspired by: https://github.com/mymusise/ChatGLM-Tuning/blob/master/finetune.py#L52
+Inspired by: https://github.com/mymusise/ChatGLM-Tuning/blob/997393046a49510e6cda36962f9a399297959311/finetune.py#L52
 """
 class TrainerForChatGLM(Trainer):

-    def compute_loss(self, model, inputs, return_outputs=False):
-        return model(**inputs).loss
-
     def _save(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         from transformers.trainer import TRAINING_ARGS_NAME
         output_dir = output_dir if output_dir is not None else self.args.output_dir
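
Taken together, the utils.py changes keep the training path away from any dtype cast: use_cache is turned off for training, ChatGLM's optional quantize() still applies, the lm_head output is routed through CastOutputToFloat, and LoRA is attached with PEFT (get_peft_model freezes the base weights itself, which is why the explicit requires_grad_(False) loop could go). The collator and trainer edits point the same way: __call__ now returns only input_ids and labels, leaving attention masks and position ids for ChatGLM to build as the inline comment notes, and the compute_loss override is dropped in favor of the Trainer's default, which uses the loss returned by the model. A condensed sketch of the resulting prepare_model() logic, with hypothetical LoRA hyperparameters (r, lora_alpha, lora_dropout are illustrative, not taken from this diff):

    import torch
    from peft import LoraConfig, TaskType, get_peft_model

    class CastOutputToFloat(torch.nn.Sequential):   # as added to utils.py above
        def forward(self, x):
            return super().forward(x).to(torch.float32)

    def prepare_for_lora_finetuning(model, quantization_bit=None):
        """Condensed sketch of prepare_model() after this commit, not the exact code."""
        model.config.use_cache = False                     # caching is only useful at inference time
        if quantization_bit is not None:
            model = model.quantize(quantization_bit)       # ChatGLM's built-in quantize(), per the diff
        model.lm_head = CastOutputToFloat(model.lm_head)   # keep logits/loss in fp32 under --fp16
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8, lora_alpha=32, lora_dropout=0.1,          # hypothetical values, not from this diff
        )
        model = get_peft_model(model, peft_config)         # PEFT marks only the LoRA weights as trainable
        return model                                       # crucially, no model.half() during training
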
