override load_best_model
hiyouga committed May 19, 2023
1 parent d900d56 commit ca873e6
Showing 5 changed files with 36 additions and 7 deletions.
examples/train_sft_with_dev_set.sh: 4 additions, 3 deletions
@@ -11,13 +11,14 @@ CUDA_VISIBLE_DEVICES=0 python ../src/train_sft.py \
     --per_device_eval_batch_size 4 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
+    --evaluation_strategy steps \
     --save_strategy steps \
     --logging_steps 10 \
-    --save_steps 1000 \
+    --eval_steps 100 \
+    --save_steps 100 \
     --learning_rate 5e-5 \
     --num_train_epochs 3.0 \
     --dev_ratio 0.01 \
-    --evaluation_strategy steps \
-    --eval_steps 100 \
+    --load_best_model_at_end \
     --plot_loss \
     --fp16
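
Context for these flags: in the Hugging Face `Trainer`, `--load_best_model_at_end` requires the evaluation and save strategies to match, and `save_steps` to be a round multiple of `eval_steps`, so that a saved checkpoint exists at every evaluation point; that is why `--eval_steps` and `--save_steps` now line up at 100. A minimal sketch of the equivalent `TrainingArguments`, assuming the script's flags map one-to-one onto `transformers.TrainingArguments` fields (`--dev_ratio` is project-specific and handled by the repository's own argument parser):

    # Sketch only: HF TrainingArguments equivalent of the flags above.
    # output_dir is a placeholder path.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="path_to_sft_checkpoint",
        evaluation_strategy="steps",   # must equal save_strategy when
        save_strategy="steps",         #   load_best_model_at_end is set
        eval_steps=100,
        save_steps=100,                # must be a round multiple of eval_steps
        logging_steps=10,
        learning_rate=5e-5,
        num_train_epochs=3.0,
        load_best_model_at_end=True,
        fp16=True,
    )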
src/utils/other.py: 4 additions, 0 deletions
@@ -146,13 +146,17 @@ def save_trainable_params(save_directory: os.PathLike, model: torch.nn.Module) -> None:


 def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+    model = unwrap_model(model)
+
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
     assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
     model_state_dict = torch.load(weights_file, map_location="cpu")
     model.load_state_dict(model_state_dict, strict=False) # skip missing keys


 def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+    model = unwrap_model(model)
+
     valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
     assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
     valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
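
The `model = unwrap_model(model)` lines are the substance of this hunk: when the trainer restores the best model at the end of training, the module it holds may still be wrapped (e.g. by `DistributedDataParallel`), and unwrapping first keeps the checkpoint's parameter names aligned with the model's. Note that `strict=False` is deliberate: the checkpoint contains only the trainable (e.g. LoRA) weights, so missing keys for the frozen backbone are expected. A hypothetical usage sketch; the model path, checkpoint path, and import path are assumptions, not from this commit:

    # Hypothetical usage: restore trainable weights written by save_trainable_params.
    from transformers import AutoModelForCausalLM
    from utils.other import load_trainable_params  # import path is an assumption

    model = AutoModelForCausalLM.from_pretrained("path_to_llama_model")
    load_trainable_params(model, "path_to_sft_checkpoint/checkpoint-100")
    model.eval()  # frozen backbone plus restored trainable weights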
src/utils/pairwise.py: 15 additions, 1 deletion
@@ -12,6 +12,8 @@
 from .other import (
     get_logger,
     save_trainable_params,
+    load_trainable_params,
+    load_valuehead_params,
     FINETUNING_ARGS_NAME
 )

@@ -49,6 +51,8 @@ def compute_loss(self, model, inputs, return_outputs=False):
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
         We use score on the EOS token to represent reward of the whole sentence.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs)
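
The rest of the loss computation is collapsed in this view. A sketch of how it plausibly continues, given the docstring (chosen examples in the first half of the batch, reward taken from the value head at the EOS token) and the standard InstructGPT-style pairwise ranking loss; the variable names are illustrative, not necessarily the repository's:

    # Assumes values has shape (2 * batch_size, seq_len); a sketch, not the repo's code.
    r_accept = values[:batch_size, -1]  # value-head score at the final (EOS) token
    r_reject = values[batch_size:, -1]
    loss = -torch.nn.functional.logsigmoid(r_accept - r_reject).mean()
    return (loss, (r_accept, r_reject)) if return_outputs else loss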
@@ -60,7 +64,7 @@

     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
-        Saves trainable parameters as model checkpoints.
+        Saves trainable parameters as model checkpoint.

         This function will only be executed at the process zero.
@@ -72,3 +76,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         save_trainable_params(output_dir, self.model)
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
         torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))
+
+    def _load_best_model(self):
+        r"""
+        Loads trainable parameters from model checkpoint.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+        """
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        load_trainable_params(self.model, self.state.best_model_checkpoint)
+        load_valuehead_params(self.model, self.state.best_model_checkpoint)
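
Why this override exists: with `load_best_model_at_end` set, `transformers.Trainer` calls `self._load_best_model()` once training finishes, and the default implementation expects a full model state dict in the best checkpoint. These trainers save only the trainable parameters (plus the value head for the reward model), so the override reloads exactly what `_save` wrote. The triggering control flow, paraphrased from the end of `Trainer.train()` (not verbatim):

    # Paraphrase of transformers.Trainer.train(), end of the training loop:
    if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        self._load_best_model()  # resolves to the overrides added in this commit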
src/utils/ppo.py: 2 additions, 2 deletions
@@ -279,9 +279,9 @@ def save_state(self, output_dir: Optional[str] = None) -> None:

     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
-        Saves trainable parameters as model checkpoints. We use `self.model.pretrained_model` to refer to the backbone model.
+        Saves trainable parameters as model checkpoint.

-        Override to inject custom behavior.
+        Subclass and override to inject custom behavior.
         """
         self.accelerator.wait_for_everyone() # must be executed before is_world_process_zero()
         if not self.is_world_process_zero():
src/utils/seq2seq.py: 11 additions, 1 deletion
@@ -19,6 +19,7 @@
 from .other import (
     get_logger,
     save_trainable_params,
+    load_trainable_params,
     IGNORE_INDEX,
     FINETUNING_ARGS_NAME,
     PREDICTION_FILE_NAME
@@ -81,7 +82,7 @@ def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs):

     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
-        Saves trainable parameters as model checkpoints.
+        Saves trainable parameters as model checkpoint.

         This function will only be executed at the process zero.
@@ -94,6 +95,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
         torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME))
+
+    def _load_best_model(self):
+        r"""
+        Loads trainable parameters from model checkpoint.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+        """
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        load_trainable_params(self.model, self.state.best_model_checkpoint)

     def prediction_step(
         self,
         model: torch.nn.Module,
