Commit 4754f18
[Bug-fix] Support for exporting intermediate checkpoints (#189)
* Save recipes with intermediate checkpoints

* Add epoch stage adjustment
KSGulin committed Mar 14, 2023
1 parent 75799e1 commit 4754f18
Showing 3 changed files with 15 additions and 10 deletions.
models/experimental.py (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch.load(attempt_download(w) if not str(w).startswith("zoo:")
                           else sparsezoo_download(w), map_location='cpu')  # load
-        sparsified = bool(ckpt.get("checkpoint_recipe"))
+        sparsified = bool(ckpt.get("checkpoint_recipe") or ckpt.get("intermediate_recipe"))

         ckpt = (
             (ckpt.get('ema') or ckpt['model']).to(device).float()
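
The widened check means a checkpoint now counts as sparsified if it carries either recipe key. A minimal sketch of the detection logic in isolation (the example dicts are hypothetical stand-ins for what torch.load returns):

    # Minimal sketch, not the repository's code.
    def is_sparsified(ckpt: dict) -> bool:
        # Before this commit only "checkpoint_recipe" was checked, so intermediate
        # checkpoints, which carry "intermediate_recipe" instead, loaded as dense.
        return bool(ckpt.get("checkpoint_recipe") or ckpt.get("intermediate_recipe"))

    final_ckpt = {"checkpoint_recipe": "recipe yaml", "intermediate_recipe": None}
    intermediate_ckpt = {"checkpoint_recipe": None, "intermediate_recipe": "recipe yaml"}
    dense_ckpt = {}

    assert is_sparsified(final_ckpt)
    assert is_sparsified(intermediate_ckpt)  # newly recognized by this commit
    assert not is_sparsified(dense_ckpt)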
utils/neuralmagic/sparsification_manager.py (14 changes: 8 additions & 6 deletions)
@@ -394,11 +394,11 @@ def rescale_gradient_accumulation(
         effective_batch_size = batch_size * accumulate
         batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
         accumulate = effective_batch_size // batch_size

         self.log_console(
-                f"Batch size rescaled to {batch_size} with {accumulate} gradient "
-                "accumulation steps for QAT"
-            )
+            f"Batch size rescaled to {batch_size} with {accumulate} gradient "
+            "accumulation steps for QAT"
+        )

         if accumulate * batch_size != effective_batch_size:
             self.log_console(
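
The changed lines in this hunk fix only the indentation of the log_console call; the substance of the surrounding code is the rescale arithmetic, which shrinks the batch size for QAT while gradient accumulation preserves the effective batch size. A standalone sketch of that arithmetic, assuming QAT_BATCH_SCALE = 2 (the real constant is defined elsewhere in this module):

    # Standalone sketch of the rescale arithmetic in rescale_gradient_accumulation.
    QAT_BATCH_SCALE = 2  # assumption for illustration

    def rescale(batch_size: int, accumulate: int) -> tuple:
        effective_batch_size = batch_size * accumulate
        batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
        accumulate = effective_batch_size // batch_size
        return batch_size, accumulate

    # Batch 64 with no accumulation becomes batch 32 with 2 accumulation steps,
    # so the effective batch size of 64 seen by the optimizer is unchanged.
    assert rescale(64, 1) == (32, 2)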
@@ -461,7 +461,7 @@ def update_state_dict_for_saving(
"""
# checkpoint recipe saved with final models, for state re-construction upon
# loading for validation or additional stage of sparsification
checkpoint_recipe = self.get_final_checkpoint_recipe() if final_epoch else None
composed_recipe = self.get_final_checkpoint_recipe()

# Pickling is not supported for quantized models for a subset of the supported
# torch versions, thus all sparse models are saved via their state dict
@@ -470,7 +470,9 @@
"yaml": ckpt["model"].yaml,
"ema": ckpt["ema"].state_dict() if ema_enabled else None,
"updates": ckpt["updates"] if ema_enabled else None,
"checkpoint_recipe": str(checkpoint_recipe) if checkpoint_recipe else None,
"checkpoint_recipe": str(composed_recipe) if final_epoch else None,
# Saved to support export of intermediate checkpoints
"intermediate_recipe": str(composed_recipe) if not final_epoch else None,
"epoch": -1 if final_epoch else ckpt["epoch"],
"nc": number_classes,
}
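
After this hunk the composed recipe is always built and written under exactly one of two mutually exclusive keys, depending on whether the save happens at the final epoch. A reduced sketch of the branching, keeping only the recipe- and epoch-related entries:

    # Reduced sketch of the save-side branching in update_state_dict_for_saving.
    # composed_recipe stands in for self.get_final_checkpoint_recipe().
    def recipe_entries(composed_recipe, final_epoch: bool, epoch: int) -> dict:
        return {
            # Final checkpoints keep the original key, so existing loaders still work.
            "checkpoint_recipe": str(composed_recipe) if final_epoch else None,
            # New key: lets intermediate checkpoints be exported and re-loaded too.
            "intermediate_recipe": str(composed_recipe) if not final_epoch else None,
            # As before, epoch -1 marks a final checkpoint.
            "epoch": -1 if final_epoch else epoch,
        }

    assert recipe_entries("recipe", final_epoch=True, epoch=57)["epoch"] == -1
    assert recipe_entries("recipe", final_epoch=False, epoch=57)["intermediate_recipe"] == "recipe"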
utils/neuralmagic/utils.py (9 changes: 6 additions & 3 deletions)
@@ -106,11 +106,14 @@ def load_sparsified_model(
     model = Yolov5Model(ckpt.get("yaml"))
     model = update_model_bottlenecks(model).to(device)
     checkpoint_manager = ScheduledModifierManager.from_yaml(
-        ckpt["checkpoint_recipe"]
+        ckpt["checkpoint_recipe"] or ckpt["intermediate_recipe"]
     )
-    checkpoint_manager.apply_structure(
-        model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf")
+    epoch = (
+        checkpoint_manager.get_last_start_epoch() + ckpt["epoch"] + ALMOST_ONE
+        if ckpt["epoch"] >= 0
+        else float("inf")
     )
+    checkpoint_manager.apply_structure(model, epoch)

     # Load state dict
     model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True)
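
This is the load-side half of the fix. For an intermediate checkpoint (epoch >= 0), the saved epoch is presumably counted from the start of the recipe stage that was running, so it is offset by get_last_start_epoch() before the recipe structure is applied; a final checkpoint (epoch == -1) still applies the full recipe via float("inf"). A sketch of the epoch computation, assuming ALMOST_ONE is a just-under-one constant (the real value is imported at the top of utils.py):

    # Sketch of the adjusted target epoch in load_sparsified_model.
    ALMOST_ONE = 1 - 1e-9  # assumed value for illustration

    def structure_epoch(last_stage_start_epoch: float, ckpt_epoch: int) -> float:
        if ckpt_epoch < 0:
            # epoch == -1 marks a final checkpoint: apply the whole recipe.
            return float("inf")
        # Intermediate checkpoint: shift the in-stage epoch by the stage's start
        # epoch, then stop just short of the next whole epoch boundary.
        return last_stage_start_epoch + ckpt_epoch + ALMOST_ONE

    # A checkpoint saved at in-stage epoch 10 of a stage starting at epoch 50
    # applies structure through epoch ~60.999...
    assert structure_epoch(50, -1) == float("inf")
    assert 60 < structure_epoch(50, 10) < 61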
