-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Legacy pickled models support #162
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import os | ||
import sys | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
|
@@ -11,7 +12,7 @@ | |
from models.yolo import Model as Yolov5Model | ||
from utils.dataloaders import create_dataloader | ||
from utils.general import LOGGER, check_dataset, check_yaml, colorstr | ||
from utils.neuralmagic.quantization import update_model_bottlenecks | ||
from utils.neuralmagic.quantization import _Add, update_model_bottlenecks | ||
from utils.torch_utils import ModelEMA | ||
|
||
__all__ = [ | ||
|
@@ -28,6 +29,20 @@ | |
RANK = int(os.getenv("RANK", -1)) | ||
ALMOST_ONE = 1 - 1e-9 # for incrementing epoch to be applied to recipe | ||
|
||
# In previous integrations of NM YOLOv5, we were pickling models as long as they are | ||
# not quantized. We've now changed to never pickling a model touched by us. This | ||
# namespace hacking is meant to address backwards compatibility with previously | ||
# pickled, pruned models. | ||
import models | ||
from models import common | ||
setattr(common, "_Add", _Add) # Definition of the _Add module has moved | ||
|
||
# If using yolov5 as a repo and not a package, allow loading of models pickled w package | ||
if "yolov5" not in sys.modules: | ||
sys.modules["yolov5"] = "" | ||
sys.modules["yolov5.models"] = models | ||
sys.modules["yolov5.models.common"] = common | ||
KSGulin marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+41
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this does work after all? You were saying it wasn't working at first right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue I was running into with this was actually for the |
||
|
||
|
||
class ToggleableModelEMA(ModelEMA): | ||
""" | ||
|
@@ -81,20 +96,27 @@ def load_sparsified_model( | |
# Load checkpoint if not yet loaded | ||
ckpt = ckpt if isinstance(ckpt, dict) else torch.load(ckpt, map_location=device) | ||
|
||
# Construct randomly initialized model model and apply sparse structure modifiers | ||
model = Yolov5Model(ckpt.get("yaml")) | ||
model = update_model_bottlenecks(model).to(device) | ||
checkpoint_manager = ScheduledModifierManager.from_yaml(ckpt["checkpoint_recipe"]) | ||
checkpoint_manager.apply_structure( | ||
model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf") | ||
) | ||
if isinstance(ckpt["model"], torch.nn.Module): | ||
KSGulin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model = ckpt["model"] | ||
|
||
else: | ||
# Construct randomly initialized model model and apply sparse structure modifiers | ||
model = Yolov5Model(ckpt.get("yaml")) | ||
model = update_model_bottlenecks(model).to(device) | ||
checkpoint_manager = ScheduledModifierManager.from_yaml( | ||
ckpt["checkpoint_recipe"] | ||
) | ||
checkpoint_manager.apply_structure( | ||
model, ckpt["epoch"] + ALMOST_ONE if ckpt["epoch"] >= 0 else float("inf") | ||
) | ||
|
||
# Load state dict | ||
model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True) | ||
# Load state dict | ||
model.load_state_dict(ckpt["ema"] or ckpt["model"], strict=True) | ||
KSGulin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model.hyp = ckpt.get("hyp") | ||
model.nc = ckpt.get("nc") | ||
|
||
model.hyp = ckpt.get("hyp") | ||
model.nc = ckpt.get("nc") | ||
model.sparsified = True | ||
model.float() | ||
|
||
return model | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hah I love this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was hoping to import it directly in common.py, but quickly found myself in circular import hell. This way is probably more intentional anyway