Skip to content

Commit

Permalink
Update sparsified model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin committed Jan 24, 2023
1 parent 165d5be commit 85c4bf9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
6 changes: 0 additions & 6 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
ckpt = torch.load(attempt_download(w) if not str(w).startswith("zoo:")
else sparsezoo_download(w), map_location='cpu') # load
sparsified = bool(ckpt.get("checkpoint_recipe"))

if sparsified:
nc = ckpt["nc"]

ckpt = (
(ckpt.get('ema') or ckpt['model']).to(device).float()
Expand All @@ -95,9 +92,6 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
if sparsified:
ckpt.nc = nc
ckpt.sparsified = True

model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') and not sparsified else ckpt.eval()) # model in eval mode

Expand Down
5 changes: 5 additions & 0 deletions utils/neuralmagic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def load_sparsified_model(

# Load state dict
model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True)

model.hyp = ckpt.get("hyp")
model.nc = ckpt.get("nc")
model.sparsified = True

return model


Expand Down

0 comments on commit 85c4bf9

Please sign in to comment.