diff --git a/models/experimental.py b/models/experimental.py index a4b4445dd699..7deef125ec00 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -79,7 +79,7 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): for w in weights if isinstance(weights, list) else [weights]: ckpt = torch.load(attempt_download(w) if not str(w).startswith("zoo:") else sparsezoo_download(w), map_location='cpu') # load - sparsified = bool(ckpt.get("checkpoint_recipe")) + sparsified = bool(ckpt.get("checkpoint_recipe") or ckpt.get("intermediate_recipe")) ckpt = ( (ckpt.get('ema') or ckpt['model']).to(device).float() diff --git a/utils/neuralmagic/sparsification_manager.py b/utils/neuralmagic/sparsification_manager.py index b74f9ce2a787..ff743538aaa7 100644 --- a/utils/neuralmagic/sparsification_manager.py +++ b/utils/neuralmagic/sparsification_manager.py @@ -394,11 +394,11 @@ def rescale_gradient_accumulation( effective_batch_size = batch_size * accumulate batch_size = max(batch_size // QAT_BATCH_SCALE, 1) accumulate = effective_batch_size // batch_size - + self.log_console( - f"Batch size rescaled to {batch_size} with {accumulate} gradient " - "accumulation steps for QAT" - ) + f"Batch size rescaled to {batch_size} with {accumulate} gradient " + "accumulation steps for QAT" + ) if accumulate * batch_size != effective_batch_size: self.log_console( @@ -461,7 +461,7 @@ def update_state_dict_for_saving( """ # checkpoint recipe saved with final models, for state re-construction upon # loading for validation or additional stage of sparsification - checkpoint_recipe = self.get_final_checkpoint_recipe() if final_epoch else None + composed_recipe = self.get_final_checkpoint_recipe() # Pickling is not supported for quantized models for a subset of the supported # torch versions, thus all sparse models are saved via their state dict @@ -470,7 +470,9 @@ def update_state_dict_for_saving( "yaml": ckpt["model"].yaml, "ema": ckpt["ema"].state_dict() if ema_enabled else None, "updates": ckpt["updates"] if ema_enabled else None, - "checkpoint_recipe": str(checkpoint_recipe) if checkpoint_recipe else None, + "checkpoint_recipe": str(composed_recipe) if final_epoch else None, + # Saved to support export of intermediate checkpoints + "intermediate_recipe": str(composed_recipe) if not final_epoch else None, "epoch": -1 if final_epoch else ckpt["epoch"], "nc": number_classes, } diff --git a/utils/neuralmagic/utils.py b/utils/neuralmagic/utils.py index e54c1a7e2e49..4ed680b2347c 100644 --- a/utils/neuralmagic/utils.py +++ b/utils/neuralmagic/utils.py @@ -106,11 +106,14 @@ def load_sparsified_model( model = Yolov5Model(ckpt.get("yaml")) model = update_model_bottlenecks(model).to(device) checkpoint_manager = ScheduledModifierManager.from_yaml( - ckpt["checkpoint_recipe"] + ckpt["checkpoint_recipe"] or ckpt["intermediate_recipe"] ) - checkpoint_manager.apply_structure( - model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf") + epoch = ( + checkpoint_manager.get_last_start_epoch() + ckpt["epoch"] + ALMOST_ONE + if ckpt["epoch"] >= 0 + else float("inf") ) + checkpoint_manager.apply_structure(model, epoch) # Load state dict model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True)