Commit 946dfcf
Merge branch 'main' of github.com:taishan1994/ChatGLM-LoRA-Tuning
taishan1994 committed May 24, 2023
2 parents: 79b4e86 + d8544bc
Showing 1 changed file with 7 additions and 6 deletions.
train_trainer.py
@@ -90,7 +90,7 @@ def on_save(self,
         # checkpoint_folder = os.path.join(
         #     args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
         # )

-        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
+        peft_model_path = os.path.join(args.output_dir, "adapter_model")
         kwargs["model"].save_pretrained(peft_model_path)
         return control
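
The first hunk changes where the on_save callback persists the LoRA adapter: instead of writing one copy per checkpoint folder (the commented-out checkpoint_folder logic), it now always writes to a single adapter_model directory directly under args.output_dir, overwritten on each save. Below is a minimal sketch of such a callback, following the common PEFT-plus-transformers Trainer pattern; the class name and surrounding structure are assumptions for illustration, not necessarily the repository's exact code:

import os
from transformers import TrainerCallback

class SavePeftModelCallback(TrainerCallback):  # class name assumed for illustration
    def on_save(self, args, state, control, **kwargs):
        # Keep exactly one copy of the adapter at args.output_dir/adapter_model,
        # overwritten at every save, rather than one copy per
        # checkpoint-<global_step> folder as before this commit.
        peft_model_path = os.path.join(args.output_dir, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        return control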

@@ -121,7 +121,7 @@ def main():
         "ignore_pad_token_for_loss": True,
         "train_batch_size": 12,
         "gradient_accumulation_steps": 1,
-        "save_dir": "./checkpoint/msra/train_trainer/adapter_model/",
+        "save_dir": "./checkpoint/msra/train_trainer/",
         "num_train_epochs": 1,
         "local_rank": -1,
         "log_steps": 10,
@@ -133,10 +133,11 @@ def main():
     args = config_parser.parse_main()

     pprint(vars(args))
-    if not os.path.exists(args.save_dir):
-        os.makedirs(args.save_dir)
-
-    with open(os.path.join(args.save_dir, "train_args.json"), "w") as fp:
+    tmp_dir = os.path.join(args.save_dir, "adapter_model")
+    if not os.path.exists(tmp_dir):
+        os.makedirs(tmp_dir)
+
+    with open(os.path.join(tmp_dir, "train_args.json"), "w") as fp:
         json.dump(vars(args), fp, ensure_ascii=False, indent=2)

     with open(args.deepspeed_jaon_path, "r") as fp:
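
The second and third hunks are two halves of one path cleanup: save_dir in the config now points at the parent run directory (./checkpoint/msra/train_trainer/), and main() appends the adapter_model subfolder itself before creating it and dumping train_args.json into it, so the training arguments end up in the same adapter_model folder the save callback targets (assuming output_dir is set from save_dir). As a hedged sketch of how the trained adapter would then be loaded back, assuming the chatglm-6b base model this repository targets (the snippet is an assumption, not the repository's own loading code):

import os
from peft import PeftModel
from transformers import AutoModel

save_dir = "./checkpoint/msra/train_trainer/"
adapter_dir = os.path.join(save_dir, "adapter_model")  # matches the new layout

# Load the frozen base model, then attach the saved LoRA adapter weights.
base = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = PeftModel.from_pretrained(base, adapter_dir)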
