
Commit

update code structure
hiyouga committed Apr 26, 2023
1 parent 469d010 commit a108614
Showing 11 changed files with 383 additions and 288 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -166,7 +166,8 @@ sys.path.append("src")
 from src import load_pretrained, ModelArguments
 model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
 model, tokenizer = load_pretrained(model_args)
-model = model.half().cuda()
+model = model.cuda()
+model.eval()
 # model.generate, model.chat()...
 ```

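For context, a minimal sketch of the README usage as it reads after this commit. `path_to_checkpoint` is a placeholder (the value below is hypothetical), and dropping the explicit `.half()` cast assumes `load_pretrained` already returns the model in the intended dtype; `.eval()` simply switches the model to inference mode.

```python
import sys
sys.path.append("src")  # run from the repository root

from src import load_pretrained, ModelArguments

path_to_checkpoint = "path_to_checkpoint"  # hypothetical; point this at your finetuned checkpoint_dir
model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
model, tokenizer = load_pretrained(model_args)
model = model.cuda()  # move to GPU
model.eval()          # inference mode (disables dropout)
# model.generate, model.chat()...
```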
3 changes: 2 additions & 1 deletion README_zh.md
@@ -171,7 +171,8 @@ sys.path.append("src")
 from src import load_pretrained, ModelArguments
 model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
 model, tokenizer = load_pretrained(model_args)
-model = model.half().cuda()
+model = model.cuda()
+model.eval()
 # model.generate, model.chat()...
 ```

Binary file modified assets/wechat.jpg
5 changes: 3 additions & 2 deletions examples/evaluate.sh
@@ -4,8 +4,9 @@ CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
     --do_eval \
     --dataset alpaca_gpt4_zh \
     --dataset_dir ../data \
-    --output_dir output_eval \
+    --checkpoint_dir path_to_checkpoint \
+    --output_dir path_to_eval_result \
     --overwrite_cache \
     --per_device_eval_batch_size 8 \
-    --max_samples 20 \
+    --max_samples 50 \
     --predict_with_generate
3 changes: 1 addition & 2 deletions examples/finetune.sh
@@ -5,14 +5,13 @@ CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
     --dataset alpaca_gpt4_zh \
     --dataset_dir ../data \
     --finetuning_type lora \
-    --output_dir output_finetune \
+    --output_dir path_to_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
-    --max_samples 10000 \
     --learning_rate 5e-5 \
     --num_train_epochs 1.0 \
     --fp16
12 changes: 5 additions & 7 deletions src/finetune.py
@@ -9,9 +9,9 @@
     prepare_data,
     preprocess_data,
     plot_loss,
-    DataCollatorForChatGLM,
+    Seq2SeqDataCollatorForChatGLM,
     ComputeMetrics,
-    TrainerForChatGLM
+    Seq2SeqTrainerForChatGLM
 )


@@ -20,9 +20,9 @@ def main():
     # Prepare pretrained model and dataset
     model_args, data_args, training_args, finetuning_args = prepare_args()
     dataset = prepare_data(model_args, data_args)
-    model, tokenizer = load_pretrained(model_args, finetuning_args, is_trainable=training_args.do_train)
+    model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, is_trainable=training_args.do_train)
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args)
-    data_collator = DataCollatorForChatGLM(
+    data_collator = Seq2SeqDataCollatorForChatGLM(
         tokenizer=tokenizer,
         model=model,
         ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
@@ -36,7 +36,7 @@ def main():
         data_args.num_beams is not None else training_args.generation_num_beams

     # Initialize our Trainer
-    trainer = TrainerForChatGLM(
+    trainer = Seq2SeqTrainerForChatGLM(
         finetuning_args=finetuning_args,
         model=model,
         args=training_args,
@@ -67,14 +67,12 @@ def main():

     # Evaluation
     if training_args.do_eval:
-        model = model.half() # don't use `--fp16` argument at evaluation
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

     # Predict
     if training_args.do_predict:
-        model = model.half()
         predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
3 changes: 2 additions & 1 deletion src/infer.py
@@ -36,7 +36,8 @@ def main():
     parser = HfArgumentParser(ModelArguments)
     model_args, = parser.parse_args_into_dataclasses()
     model, tokenizer = load_pretrained(model_args)
-    model = model.half().cuda()
+    model = model.cuda()
+    model.eval()

     history = []
     print(welcome)
9 changes: 6 additions & 3 deletions src/utils/__init__.py
@@ -2,10 +2,13 @@
     load_pretrained,
     prepare_args,
     prepare_data,
-    preprocess_data,
-    DataCollatorForChatGLM,
+    preprocess_data
+)
+
+from .seq2seq import (
+    Seq2SeqDataCollatorForChatGLM,
     ComputeMetrics,
-    TrainerForChatGLM
+    Seq2SeqTrainerForChatGLM
 )

 from .config import ModelArguments
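Read together with the src/finetune.py hunks above, the seq2seq-specific pieces now live in utils/seq2seq.py while the package keeps a single import surface. Below is a minimal sketch of how a downstream script imports the renamed helpers after this commit; the `from utils import ...` form and the `plot_loss` re-export are inferred from the finetune.py hunk rather than shown in this file's diff.

```python
# Import surface after the restructuring (names taken from the diffs above).
from utils import (
    load_pretrained,                # model + tokenizer loading
    prepare_args,                   # argument parsing
    prepare_data,                   # raw dataset loading
    preprocess_data,                # tokenization / formatting
    plot_loss,                      # loss-curve plotting (listed in finetune.py's imports)
    Seq2SeqDataCollatorForChatGLM,  # renamed from DataCollatorForChatGLM
    ComputeMetrics,                 # generation metrics for eval/predict
    Seq2SeqTrainerForChatGLM,       # renamed from TrainerForChatGLM
)
from utils import ModelArguments    # re-exported from utils.config
```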