Skip to content

Commit

Permalink
Legacy pickled models support (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin committed Jan 31, 2023
1 parent c1917d7 commit 54c1108
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions utils/neuralmagic/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

Expand All @@ -11,7 +12,7 @@
from models.yolo import Model as Yolov5Model
from utils.dataloaders import create_dataloader
from utils.general import LOGGER, check_dataset, check_yaml, colorstr
from utils.neuralmagic.quantization import update_model_bottlenecks
from utils.neuralmagic.quantization import _Add, update_model_bottlenecks
from utils.torch_utils import ModelEMA

__all__ = [
Expand All @@ -28,6 +29,20 @@
RANK = int(os.getenv("RANK", -1))
ALMOST_ONE = 1 - 1e-9 # for incrementing epoch to be applied to recipe

# In previous integrations of NM YOLOv5, we were pickling models as long as they are
# not quantized. We've now changed to never pickling a model touched by us. This
# namespace hacking is meant to address backwards compatibility with previously
# pickled, pruned models.
import models
from models import common
setattr(common, "_Add", _Add) # Definition of the _Add module has moved

# If using yolov5 as a repo and not a package, allow loading of models pickled w package
if "yolov5" not in sys.modules:
sys.modules["yolov5"] = ""
sys.modules["yolov5.models"] = models
sys.modules["yolov5.models.common"] = common


class ToggleableModelEMA(ModelEMA):
"""
Expand Down Expand Up @@ -81,20 +96,27 @@ def load_sparsified_model(
# Load checkpoint if not yet loaded
ckpt = ckpt if isinstance(ckpt, dict) else torch.load(ckpt, map_location=device)

# Construct randomly initialized model model and apply sparse structure modifiers
model = Yolov5Model(ckpt.get("yaml"))
model = update_model_bottlenecks(model).to(device)
checkpoint_manager = ScheduledModifierManager.from_yaml(ckpt["checkpoint_recipe"])
checkpoint_manager.apply_structure(
model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf")
)
if isinstance(ckpt["model"], torch.nn.Module):
model = ckpt["model"]

else:
# Construct randomly initialized model model and apply sparse structure modifiers
model = Yolov5Model(ckpt.get("yaml"))
model = update_model_bottlenecks(model).to(device)
checkpoint_manager = ScheduledModifierManager.from_yaml(
ckpt["checkpoint_recipe"]
)
checkpoint_manager.apply_structure(
model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf")
)

# Load state dict
model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True)
# Load state dict
model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True)
model.hyp = ckpt.get("hyp")
model.nc = ckpt.get("nc")

model.hyp = ckpt.get("hyp")
model.nc = ckpt.get("nc")
model.sparsified = True
model.float()

return model

Expand Down

0 comments on commit 54c1108

Please sign in to comment.