Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SparseML Integration to V6.1 #26

Merged
merged 17 commits into from
Apr 8, 2022
Merged
Prev Previous commit
Next Next commit
Fix: non-recipe runs
  • Loading branch information
KSGulin committed Apr 8, 2022
commit 912040caa773d70f247fc0c2ff9e385145f94258
3 changes: 1 addition & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def load_checkpoint(
p.requires_grad = True

# load sparseml recipe for applying pruning and quantization
checkpoint_recipe = None
checkpoint_recipe = train_recipe = None
if resume:
train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None
elif ckpt['recipe'] or recipe:
Expand Down Expand Up @@ -526,7 +526,6 @@ def load_checkpoint(
return model, {
'ckpt': ckpt,
'state_dict': state_dict,
'start_epoch': start_epoch,
'sparseml_wrapper': sparseml_wrapper,
'report': report,
}
Expand Down
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
resume=opt.resume,
rank=LOCAL_RANK
)
ckpt, state_dict, sparseml_wrapper, start_epoch = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'], extras['start_epoch']
ckpt, state_dict, sparseml_wrapper = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper']
LOGGER.info(extras['report'])
else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
Expand Down Expand Up @@ -196,7 +196,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ema = ModelEMA(model, enabled=not opt.disable_ema) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = sparseml_wrapper.start_epoch, 0.0
start_epoch = sparseml_wrapper.start_epoch or 0
best_fitness = 0.0
if pretrained:
if opt.resume:
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
Expand Down