restore adam state and save step and epoch for lr decaying to work
heewooj committed Jun 9, 2020
1 parent 7939619 commit 86d6319
Showing 2 changed files with 22 additions and 7 deletions.
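
In short: save_checkpoint now takes the logger object (rather than just its logdir) so the checkpoint records the current step (logger.iters) and epoch (logger.epoch); restore is renamed restore_model; and a new restore_opt reloads the Adam optimizer state and advances the LR scheduler to the saved step, so learning-rate decay resumes where it left off instead of restarting from the initial rate.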
21 changes: 16 additions & 5 deletions jukebox/make_models.py
@@ -38,17 +38,19 @@ def load_checkpoint(path):
     print("Restored from {}".format(restore))
     return checkpoint
 
-def save_checkpoint(logdir, name, model, opt, metrics, hps):
+def save_checkpoint(logger, name, model, opt, metrics, hps):
     with t.no_grad():
         save_hps = {**hps}
         save_hps = {k: v for k,v in save_hps.items() if k not in ['metadata_v2','metadata_v3', 'alignments', 'lyric_processor', 'midi_processor']}
         t.save({'hps': save_hps,
                 'model': model.state_dict(), # should also save bottleneck k's as buffers
                 'opt': opt.state_dict() if opt is not None else None,
-                **metrics}, f'{logdir}/checkpoint_{name}.pth.tar')
+                'step': logger.iters,
+                'epoch': logger.epoch,
+                **metrics}, f'{logger.logdir}/checkpoint_{name}.pth.tar')
     return
 
-def restore(hps, model, checkpoint_path):
+def restore_model(hps, model, checkpoint_path):
     model.step = 0
     if checkpoint_path != '':
         checkpoint = load_checkpoint(checkpoint_path)
@@ -60,6 +62,15 @@ def restore(hps, model, checkpoint_path):
         model.load_state_dict(checkpoint['model'])
         if 'step' in checkpoint: model.step = checkpoint['step']
 
+def restore_opt(opt, shd, checkpoint_path):
+    if not checkpoint_path:
+        return
+    checkpoint = load_checkpoint(checkpoint_path)
+    if "opt" in checkpoint:
+        opt.load_state_dict(checkpoint['opt'])
+    if "step" in checkpoint:
+        shd.step(checkpoint['step'])
+
 def make_vqvae(hps, device='cuda'):
     from jukebox.vqvae.vqvae import VQVAE
     block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
@@ -82,7 +93,7 @@ def make_vqvae(hps, device='cuda'):
                   **block_kwargs)
 
     vqvae = vqvae.to(device)
-    restore(hps, vqvae, hps.restore_vqvae)
+    restore_model(hps, vqvae, hps.restore_vqvae)
     if hps.train and not hps.prior:
         print_all(f"Loading vqvae in train mode")
         if hps.restore_vqvae != '':
@@ -166,7 +177,7 @@ def make_prior(hps, vqvae, device='cuda'):
         from jukebox.transformer.ops import _convert_conv_weights_to_fp16
         prior.apply(_convert_conv_weights_to_fp16)
     prior = prior.to(device)
-    restore(hps, prior, hps.restore_prior)
+    restore_model(hps, prior, hps.restore_prior)
     if hps.train:
         print_all(f"Loading prior in train mode")
         pass
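
For intuition, here is a minimal, self-contained sketch of the round trip that restore_opt enables, using a stand-in model and a plain PyTorch LambdaLR schedule (the real objects come from make_prior/make_vqvae and get_optimizer/get_lr_scheduler; the decay function, filename, and step count below are illustrative):

import torch

# --- a pretend training run: LR decays every step ---
model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
shd = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 0.99 ** step)
for _ in range(100):
    opt.step()
    shd.step()
torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'step': 100},
           'ckpt.pth.tar')

# --- a restart: rebuild everything, then restore ---
model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
shd = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 0.99 ** step)
ckpt = torch.load('ckpt.pth.tar')
model.load_state_dict(ckpt['model'])
opt.load_state_dict(ckpt['opt'])  # Adam moments (exp_avg, exp_avg_sq) come back intact
for _ in range(ckpt['step']):     # fast-forward the decay; the diff's shd.step(checkpoint['step'])
    shd.step()                    # did this jump in a single call on the PyTorch of that era
print(shd.get_last_lr())          # ~3e-4 * 0.99**100, not the fresh 3e-4

Without this, a resumed run would re-estimate Adam's moments from zero and decay the LR from scratch, which is exactly what the commit fixes.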
8 changes: 6 additions & 2 deletions jukebox/train.py
@@ -12,7 +12,7 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from jukebox.hparams import setup_hparams
-from jukebox.make_models import make_vqvae, make_prior, save_checkpoint
+from jukebox.make_models import make_vqvae, make_prior, restore_opt, save_checkpoint
 from jukebox.utils.logger import init_logging
 from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
 from jukebox.utils.torch_utils import zero_grad, count_parameters
@@ -86,6 +86,9 @@ def get_optimizer(model, hps):
     # lr scheduler
     shd = get_lr_scheduler(opt, hps)
 
+    restore_path = hps.restore_prior if hps.prior else hps.restore_vqvae
+    restore_opt(opt, shd, restore_path)
+
     # fp16 dynamic loss scaler
     scalar = None
     if hps.fp16:
@@ -266,7 +269,7 @@ def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
             orig_model.eval()
             name = 'latest' if hps.prior else f'step_{logger.iters}'
             if dist.get_rank() % 8 == 0:
-                save_checkpoint(logger.logdir, name, orig_model, opt, dict(step=logger.iters), hps)
+                save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
             orig_model.train()
             if ema is not None: ema.swap()
 
@@ -321,6 +324,7 @@ def run(hps="teeny", port=29500, **kwargs):
     for epoch in range(hps.curr_epoch, hps.epochs):
         metrics.reset()
         data_processor.set_epoch(epoch)
+        logger.epoch = epoch
         if hps.train:
             train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
             train_metrics['epoch'] = epoch
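
Putting the two files together, the checkpoint a restarted run finds on disk now looks schematically like this (illustrative stand-in values; the field names are exactly those written by save_checkpoint above):

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())
save_hps = {'lr': 3e-4, 'epochs': 10}  # hps minus the non-picklable entries

torch.save({'hps': save_hps,
            'model': model.state_dict(),  # incl. bottleneck codebooks saved as buffers
            'opt': opt.state_dict(),      # Adam state, or None when no optimizer was passed
            'step': 12345,                # logger.iters; restore_opt feeds this to the scheduler
            'epoch': 3,                   # logger.epoch, set each epoch by run()
            **dict(step=12345)},          # the metrics dict train() passes in
           'checkpoint_latest.pth.tar')

On restore, restore_model consumes 'model' and 'step', restore_opt consumes 'opt' and 'step', and 'epoch' is presumably what hps.curr_epoch gets set from so the epoch loop in run() continues rather than starting over.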
