From 85c4bf99b5947173afb6b900ce469c936866c9b7 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Tue, 24 Jan 2023 12:13:13 +0000 Subject: [PATCH] Update sparsified model loading --- models/experimental.py | 6 ------ utils/neuralmagic/utils.py | 5 +++++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/models/experimental.py b/models/experimental.py index 5b0108f96f1d..a4b4445dd699 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -80,9 +80,6 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): ckpt = torch.load(attempt_download(w) if not str(w).startswith("zoo:") else sparsezoo_download(w), map_location='cpu') # load sparsified = bool(ckpt.get("checkpoint_recipe")) - - if sparsified: - nc = ckpt["nc"] ckpt = ( (ckpt.get('ema') or ckpt['model']).to(device).float() @@ -95,9 +92,6 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): ckpt.stride = torch.tensor([32.]) if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)): ckpt.names = dict(enumerate(ckpt.names)) # convert to dict - if sparsified: - ckpt.nc = nc - ckpt.sparsified = True model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') and not sparsified else ckpt.eval()) # model in eval mode diff --git a/utils/neuralmagic/utils.py b/utils/neuralmagic/utils.py index 17dd74844f86..6d46216fbecd 100644 --- a/utils/neuralmagic/utils.py +++ b/utils/neuralmagic/utils.py @@ -91,6 +91,11 @@ def load_sparsified_model( # Load state dict model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True) + + model.hyp = ckpt.get("hyp") + model.nc = ckpt.get("nc") + model.sparsified = True + return model