Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SparseML Integration to V6.1 #26

Merged
merged 17 commits into from
Apr 8, 2022
Merged
Prev Previous commit
Next Next commit
Update: multi-stage recipe support
  • Loading branch information
KSGulin committed Apr 6, 2022
commit e5999d577cd4172bea7ac78cca00020d467c74f5
17 changes: 3 additions & 14 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,24 +498,13 @@ def load_checkpoint(
p.requires_grad = True

# load sparseml recipe for applying pruning and quantization
recipe = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe
sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe)
recipe_new = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe
recipe_base = None if resume else ckpt['recipe']
sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe_new, recipe_base)
exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume
loaded = False

if not train_type:
# update param names for yolov5x5 models (model.x -> model.model.x)
'''
if ('version' not in ckpt or ckpt['version'] < 6) and sparseml_wrapper.manager is not None:
for modifier in sparseml_wrapper.manager.pruning_modifiers:
updated_params = []
for param in modifier.params:
updated_params.append(
"model." + param if (param.startswith('model.') and
not param.startswith('model.model.')) else param
)
modifier.params = updated_params
'''
# apply the recipe to create the final state of the model when not training
sparseml_wrapper.apply()
else:
Expand Down
11 changes: 7 additions & 4 deletions utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def check_download_sparsezoo_weights(path):


class SparseMLWrapper(object):
def __init__(self, model, recipe):
self.enabled = bool(recipe)
def __init__(self, model, recipe_new, recipe_base=None):
    """Wrap a model with SparseML recipe management.

    Args:
        model: the model to wrap; if wrapped in DataParallel/DDP, the
            underlying ``.module`` is unwrapped and stored instead.
        recipe_new: path/YAML of the recipe to apply now; a falsy value
            disables the wrapper entirely.
        recipe_base: optional previously-applied recipe. When given, it is
            composed with ``recipe_new`` into a single multi-stage manager
            so earlier sparsification stages are preserved.
    """
    self.enabled = bool(recipe_new)
    # unwrap (D)DP containers so modifiers attach to the real module
    self.model = model.module if is_parallel(model) else model
    if self.enabled:
        # multi-stage support: stack the base (already-applied) recipe
        # under the new one; otherwise build a manager from the new recipe
        self.manager = (
            ScheduledModifierManager.compose_staged(recipe_base, recipe_new)
            if recipe_base
            else ScheduledModifierManager.from_yaml(recipe_new)
        )
    else:
        self.manager = None
    self.logger = None

def state_dict(self):
Expand Down