Skip to content

Commit

Permalink
Styling
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin committed Feb 21, 2023
1 parent 2dd2786 commit 6a0896d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
17 changes: 13 additions & 4 deletions utils/neuralmagic/quantization.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
import torch
import onnx
from models.common import Bottleneck, GhostBottleneck, DetectMultiBackend
import torch

from models.common import Bottleneck, DetectMultiBackend, GhostBottleneck

try:
from torch.nn.quantized import FloatFunctional
except Exception:
FloatFunctional = None

__all__ = ["NMGhostBottleneck", "NMBottleneck", "update_model_bottlenecks", "is_quantized"]
__all__ = [
"NMGhostBottleneck",
"NMBottleneck",
"update_model_bottlenecks",
"is_quantized",
]


def is_quantized(model: DetectMultiBackend) -> bool:
"""
Check if DetectMultiBackend model is quantized
"""
onnx_model = onnx.load_model(model.ds_engine.model_path)
return onnx_model.graph.input[0].type.tensor_type.elem_type == onnx.TensorProto.UINT8
return (
onnx_model.graph.input[0].type.tensor_type.elem_type == onnx.TensorProto.UINT8
)


class _Add(torch.nn.Module):
Expand Down
4 changes: 1 addition & 3 deletions utils/neuralmagic/sparsification_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,7 @@ def initialize(
# construct a ToggleableModelEMA from ModelEMA, allowing for on/off toggle
if ema:
# QAT is active at the start epoch, disable ema
qat_active = (
self.has_qat_phase and start_epoch >= self.first_qat_epoch
)
qat_active = self.has_qat_phase and start_epoch >= self.first_qat_epoch

ema = load_ema(
ema.ema.state_dict(),
Expand Down
5 changes: 2 additions & 3 deletions utils/neuralmagic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# pickled, pruned models.
import models
from models import common

setattr(common, "_Add", _Add) # Definition of the _Add module has moved

# If using yolov5 as a repo and not a package, allow loading of models pickled w package
Expand Down Expand Up @@ -62,9 +63,7 @@ def sparsezoo_download(path: str, recipe: Optional[str] = None) -> str:
"""
Loads model from the SparseZoo and override the path with the new download path
"""
return download_framework_model_by_recipe_type(
Model(path), recipe, "pt"
)
return download_framework_model_by_recipe_type(Model(path), recipe, "pt")


def load_ema(
Expand Down

0 comments on commit 6a0896d

Please sign in to comment.