Skip to content

Commit

Permalink
fix loading best model
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jun 27, 2023
1 parent c4d309e commit 194ca0d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/utils/peft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

if self.finetuning_args.finetuning_type == "lora":
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)
else: # freeze/full-tuning or p_tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint)

0 comments on commit 194ca0d

Please sign in to comment.