Skip to content

Commit

Permalink
remove compression_model from lm checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed Dec 11, 2023
1 parent 89baee7 commit 26dfbe5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Adding stereo models.

Removed compression model state from the LM checkpoints, for consistency, it
should always be loaded from the original `compression_model_checkpoint`.


## [1.1.0] - 2023-11-06

Expand Down
20 changes: 18 additions & 2 deletions audiocraft/solvers/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash


class MusicGenSolver(base.StandardSolver):
Expand Down Expand Up @@ -143,7 +143,7 @@ def build_model(self) -> None:
# initialize optimization
self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
self.register_stateful('model', 'optimizer', 'lr_scheduler')
self.register_best_state('model')
self.autocast_dtype = {
'float16': torch.float16, 'bfloat16': torch.bfloat16
Expand Down Expand Up @@ -181,6 +181,22 @@ def load_state_dict(self, state: dict) -> None:
key = prefix + key
assert key not in model_state
model_state[key] = value
if 'compression_model' in state:
# We used to store the `compression_model` state in the checkpoint, however
# this is in general not needed, as the compression model should always be readable
# from the original `cfg.compression_model_checkpoint` location.
compression_model_state = state.pop('compression_model')
before_hash = model_hash(self.compression_model)
self.compression_model.load_state_dict(compression_model_state)
after_hash = model_hash(self.compression_model)
if before_hash != after_hash:
raise RuntimeError(
"The compression model state inside the checkpoint is different"
" from the one obtained from compression_model_checkpoint..."
"We do not support altering the compression model inside the LM "
"checkpoint as parts of the code, in particular for running eval post-training "
"will use the compression_model_checkpoint as the source of truth.")

super().load_state_dict(state)

def load_from_pretrained(self, name: str):
Expand Down

0 comments on commit 26dfbe5

Please sign in to comment.