
Commit

hiyouga committed Jul 21, 2023
1 parent 3ed046a · commit eb26e3a
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/glmtuner/extras/save_and_load.py
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Dict
+from typing import Dict, Optional
 
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 from transformers.modeling_utils import load_sharded_checkpoint
@@ -12,12 +12,12 @@
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
+def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
     state_dict = model.state_dict()
     filtered_state_dict = {}
 
     for k, v in model.named_parameters():
-        if v.requires_grad:
+        if (not trainable_only) or v.requires_grad:
             filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
 
     return filtered_state_dict
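A minimal usage sketch (not part of this commit) of what the new trainable_only flag changes in get_state_dict; the toy model is an assumption for illustration, and the import path follows the file location shown above:

import torch
from glmtuner.extras.save_and_load import get_state_dict  # import path assumed from src/glmtuner/extras/save_and_load.py

# Toy model: freeze the first layer to mimic parameter-efficient tuning.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
for param in model[0].parameters():
    param.requires_grad = False

adapter_weights = get_state_dict(model)                     # default: trainable parameters only
full_weights = get_state_dict(model, trainable_only=False)  # new: every parameter

print(sorted(adapter_weights))  # ['1.bias', '1.weight']
print(sorted(full_weights))     # ['0.bias', '0.weight', '1.bias', '1.weight']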
2 changes: 1 addition & 1 deletion src/glmtuner/tuner/core/trainer.py
@@ -56,7 +56,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str,
         backbone_model.config.use_cache = True
         backbone_model.save_pretrained(
             output_dir,
-            state_dict=get_state_dict(backbone_model),
+            state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
             safe_serialization=self.args.save_safetensors
         )
         backbone_model.config.use_cache = False
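The trainer.py change makes the saved checkpoint depend on the tuning mode: full fine-tuning writes out the entire backbone, while any other finetuning_type keeps only the trainable subset. A hedged sketch of that decision (the example finetuning_type values are assumptions, not taken from this diff):

# Illustration only; mirrors the expression passed to get_state_dict in _save above.
def save_full_state_dict(finetuning_type: str) -> bool:
    # Full fine-tuning updates every weight, so the whole model must be written out;
    # parameter-efficient methods only need the trainable parameters.
    return finetuning_type == "full"

for finetuning_type in ("full", "lora", "freeze"):  # example values, assumed
    trainable_only = not save_full_state_dict(finetuning_type)
    print(f"{finetuning_type}: trainable_only={trainable_only}")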
