From 8eaecd23aa4490ce02c5ad1a0872c073cabc3205 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 25 Mar 2022 13:45:08 -0400 Subject: [PATCH 01/16] SparseML integration --- detect.py | 3 +- export.py | 186 ++++++++++++++++++++++++++--- models/common.py | 48 ++++---- models/yolo.py | 16 ++- train.py | 129 ++++++++++++-------- utils/activations.py | 17 +++ utils/downloads.py | 3 + utils/general.py | 8 +- utils/loggers/wandb/wandb_utils.py | 4 + utils/torch_utils.py | 30 ++++- val.py | 3 +- 11 files changed, 346 insertions(+), 101 deletions(-) diff --git a/detect.py b/detect.py index ccb9fbf5103f..559b3414f506 100644 --- a/detect.py +++ b/detect.py @@ -45,6 +45,7 @@ increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box from utils.torch_utils import select_device, time_sync +from export import load_checkpoint @torch.no_grad() @@ -89,7 +90,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) # Load model device = select_device(device) - model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + model, extras = load_checkpoint(type_='val', weights=weights, device=device) # load FP32 model stride, names, pt = model.stride, model.names, model.pt imgsz = check_img_size(imgsz, s=stride) # check image size diff --git a/export.py b/export.py index 2d4a68e62f89..078b3c6940f0 100644 --- a/export.py +++ b/export.py @@ -43,6 +43,7 @@ """ import argparse +from copy import deepcopy import json import os import platform @@ -57,20 +58,26 @@ import torch.nn as nn from torch.utils.mobile_optimizer import optimize_for_mobile +from sparseml.pytorch.utils import ModuleExporter +from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize + FILE = Path(__file__).resolve() ROOT = FILE.parents[0] # YOLOv5 root directory if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # 
relative -from models.common import Conv +from models.common import Conv, DetectMultiBackend from models.experimental import attempt_load -from models.yolo import Detect +from models.yolo import Detect, Model from utils.activations import SiLU from utils.datasets import LoadImages from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr, - file_size, print_args, url2file) -from utils.torch_utils import select_device + file_size, print_args, url2file, intersect_dicts) +from utils.torch_utils import select_device, torch_distributed_zero_first, is_parallel +from utils.downloads import attempt_download +from utils.sparse import SparseMLWrapper, check_download_sparsezoo_weights + def export_formats(): @@ -118,14 +125,33 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') f = file.with_suffix('.onnx') - torch.onnx.export(model, im, f, verbose=False, opset_version=opset, - training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL, - do_constant_folding=not train, - input_names=['images'], - output_names=['output'], - dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640) - 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85) - } if dynamic else None) + # export through SparseML so quantized and pruned graphs can be corrected + save_dir = f.parent.absolute() + save_name = str(f).split(os.path.sep)[-1] + + # get the number of outputs so we know how to name and change dynamic axes + # nested outputs can be returned if model is exported with dynamic + def _count_outputs(outputs): + count = 0 + if isinstance(outputs, list) or isinstance(outputs, tuple): + for out in outputs: + count += _count_outputs(out) + else: + count += 1 + return count + + outputs = model(im) + num_outputs = _count_outputs(outputs) + input_names = ['input'] + output_names = [f'out_{i}' for 
i in range(num_outputs)] + dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if dynamic else None + exporter = ModuleExporter(model, save_dir) + exporter.export_onnx(im, name=save_name, convert_qat=True, + input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes) + try: + skip_onnx_input_quantize(f, f) + except: + pass # Checks model_onnx = onnx.load(f) # load onnx model @@ -407,6 +433,128 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') +def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs): + pickle = not sparseml_wrapper.qat_active(epoch) # qat does not support pickled exports + ckpt_model = deepcopy(model.module if is_parallel(model) else model).float() + yaml = ckpt_model.yaml + if not pickle: + ckpt_model = ckpt_model.state_dict() + + version = 6 if isinstance([module for module in model.model.modules()][1], Conv) else 5 + + return {'epoch': epoch, + 'model': ckpt_model, + 'optimizer': optimizer.state_dict(), + 'yaml': yaml, + 'hyp': model.hyp, + 'version': version, + **ema.state_dict(pickle), + **sparseml_wrapper.state_dict(), + **kwargs} + +def load_checkpoint( + type_, + weights, + device, + cfg=None, + hyp=None, + nc=None, + data=None, + dnn=False, + half = False, + recipe=None, + resume=None, + rank=-1 + ): + with torch_distributed_zero_first(rank): + # download if not found locally or from sparsezoo if stub + weights = attempt_download(weights) or check_download_sparsezoo_weights(weights) + ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple) + else weights, map_location="cpu") # load checkpoint + start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0 + pickled = isinstance(ckpt['model'], nn.Module) + train_type = type_ == 'train' + ensemble_type = type_ == 'ensemble' + val_type = type_ =='val' + + if pickled and ensemble_type: + cfg = None + if 
ensemble_type: + model = attempt_load(weights, map_location=device) # load ensemble using pickled + state_dict = model.state_dict() + elif val_type: + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + state_dict = model.model.state_dict() + else: + # load model from config and weights + cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \ + (ckpt['model'].yaml if pickled else None) + model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc, + anchors=hyp.get('anchors') if hyp else None).to(device) + model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model' + state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key] + if val_type: + model = DetectMultiBackend(model=model, device=device, dnn=dnn, data=data, fp16=half) + + # turn gradients for params back on in case they were removed + for p in model.parameters(): + p.requires_grad = True + + # load sparseml recipe for applying pruning and quantization + recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None) + sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe) + exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume + loaded = False + + if not train_type: + # update param names for yolov5x5 models (model.x -> model.model.x) + ''' + if ('version' not in ckpt or ckpt['version'] < 6) and sparseml_wrapper.manager is not None: + for modifier in sparseml_wrapper.manager.pruning_modifiers: + updated_params = [] + for param in modifier.params: + updated_params.append( + "model." 
+ param if (param.startswith('model.') and + not param.startswith('model.model.')) else param + ) + modifier.params = updated_params + ''' + # apply the recipe to create the final state of the model when not training + sparseml_wrapper.apply() + else: + # intialize the recipe for training and restore the weights before if no quantized weights + quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()]) + if not quantized_state_dict: + state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors) + loaded = True + sparseml_wrapper.initialize(start_epoch) + + if not loaded: + state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors) + + model.float() + report = 'Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights) + + return model, { + 'ckpt': ckpt, + 'state_dict': state_dict, + 'start_epoch': start_epoch, + 'sparseml_wrapper': sparseml_wrapper, + 'report': report, + } + + +def load_state_dict(model, state_dict, train, exclude_anchors): + # fix older state_dict names not porting to the new model setup + state_dict = {key if not key.startswith("module.") else key[7:]: val for key, val in state_dict.items()} + + if train: + # load any missing weights from the model + state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=['anchor'] if exclude_anchors else []) + + model.load_state_dict(state_dict, strict=not train) # load + + return state_dict @torch.no_grad() def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' @@ -414,7 +562,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' imgsz=(640, 640), # image (height, width) batch_size=1, # batch size device='cpu', # cuda device, i.e. 
0 or 0,1,2,3 or cpu - include=('torchscript', 'onnx'), # include formats + include=('onnx'), # include formats half=False, # FP16 half-precision export inplace=False, # set YOLOv5 Detect() inplace=True train=False, # model.train() mode @@ -430,7 +578,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' topk_per_class=100, # TF.js NMS: topk per class to keep topk_all=100, # TF.js NMS: topk for all classes to keep iou_thres=0.45, # TF.js NMS: IoU threshold - conf_thres=0.25 # TF.js NMS: confidence threshold + conf_thres=0.25, # TF.js NMS: confidence threshold + remove_grid=False, ): t = time.time() include = [x.lower() for x in include] # to lowercase @@ -443,8 +592,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' # Load PyTorch model device = select_device(device) assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' - model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model - nc, names = model.nc, model.names # number of classes, class names + model, extras = load_checkpoint(type_='ensemble', weights=weights, device=device) # load FP32 model + sparseml_wrapper = extras['sparseml_wrapper'] + nc, names = extras["ckpt"]["nc"], model.names # number of classes, class names # Checks imgsz *= 2 if len(imgsz) == 1 else 1 # expand @@ -469,6 +619,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' m.onnx_dynamic = dynamic if hasattr(m, 'forward_export'): m.forward = m.forward_export # assign custom forward (optional) + model.model[-1].export = not remove_grid # set Detect() layer grid export for _ in range(2): y = model(im) # dry runs @@ -541,6 +692,7 @@ def parse_opt(): parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js 
NMS: confidence threshold') + parser.add_argument("--remove-grid", action="store_true", help="remove export of Detect() layer grid") parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx'], help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs') @@ -556,4 +708,4 @@ def main(opt): if __name__ == "__main__": opt = parse_opt() - main(opt) + main(opt) \ No newline at end of file diff --git a/models/common.py b/models/common.py index 115e3c3145ff..e0b783f55033 100644 --- a/models/common.py +++ b/models/common.py @@ -31,7 +31,7 @@ def autopad(k, p=None): # kernel, padding # Pad to 'same' if p is None: - p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad return p @@ -121,7 +121,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu def forward(self, x): y1 = self.cv3(self.m(self.cv1(x))) y2 = self.cv2(x) - return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) class C3(nn.Module): @@ -131,12 +131,12 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1) - self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) - # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) + # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) def forward(self, x): - return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) class C3TR(C3): @@ -194,7 +194,7 @@ def forward(self, x): warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning y1 
= self.m(x) y2 = self.m(y1) - return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1)) + return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) class Focus(nn.Module): @@ -205,7 +205,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k # self.contract = Contract(gain=2) def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) - return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) # return self.conv(self.contract(x)) @@ -219,7 +219,7 @@ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, s def forward(self, x): y = self.cv1(x) - return torch.cat((y, self.cv2(y)), 1) + return torch.cat([y, self.cv2(y)], 1) class GhostBottleneck(nn.Module): @@ -277,7 +277,7 @@ def forward(self, x): class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False): + def __init__(self, weights='yolov5s.pt', model=None, device=torch.device('cpu'), dnn=False, data=None, fp16=False): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript @@ -303,11 +303,11 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, names = yaml.safe_load(f)['names'] # class names if pt: # PyTorch - model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) + model = model or (attempt_load(weights if isinstance(weights, list) else w, map_location=device)) stride = max(int(model.stride.max()), 32) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names model.half() if fp16 else model.float() - self.model = model # explicitly assign for to(), cpu(), cuda(), half() + self.model = model.model # explicitly assign for to(), cpu(), cuda(), half() 
elif jit: # TorchScript LOGGER.info(f'Loading {w} for TorchScript inference...') extra_files = {'config.txt': ''} # model metadata @@ -527,7 +527,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference if isinstance(imgs, torch.Tensor): # torch - with amp.autocast(autocast): + with amp.autocast(enabled=autocast): return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference # Pre-process @@ -550,19 +550,19 @@ def forward(self, imgs, size=640, augment=False, profile=False): shape1.append([y * g for y in s]) imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape - x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 t.append(time_sync()) - with amp.autocast(autocast): + with amp.autocast(enabled=autocast): # Inference y = self.model(x, augment, profile) # forward t.append(time_sync()) # Post-process - y = non_max_suppression(y if self.dmb else y[0], self.conf, self.iou, self.classes, self.agnostic, - self.multi_label, max_det=self.max_det) # NMS + y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes, + agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det) # NMS for i in range(n): scale_coords(shape1, y[i][:, :4], shape0[i]) @@ -589,7 +589,7 @@ def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms) self.s = 
shape # inference BCHW shape - def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): + def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')): crops = [] for i, (im, pred) in enumerate(zip(self.imgs, self.pred)): s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string @@ -606,7 +606,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label, 'im': save_one_box(box, im, file=file, save=save)}) else: # all others - annotator.box_label(box, label if labels else '', color=colors(cls)) + annotator.box_label(box, label, color=colors(cls)) im = annotator.im else: s += '(no detections)' @@ -633,19 +633,19 @@ def print(self): LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t) - def show(self, labels=True): - self.display(show=True, labels=labels) # show results + def show(self): + self.display(show=True) # show results - def save(self, labels=True, save_dir='runs/detect/exp'): + def save(self, save_dir='runs/detect/exp'): save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir - self.display(save=True, labels=labels, save_dir=save_dir) # save results + self.display(save=True, save_dir=save_dir) # save results def crop(self, save=True, save_dir='runs/detect/exp'): save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None return self.display(crop=True, save=save, save_dir=save_dir) # crop results - def render(self, labels=True): - self.display(render=True, labels=labels) # render results + def render(self): + self.display(render=True) # render results return self.imgs def pandas(self): diff --git a/models/yolo.py b/models/yolo.py index 9f4701c49f9d..f08d41ce1585 100644 --- a/models/yolo.py +++ 
b/models/yolo.py @@ -19,6 +19,7 @@ from models.common import * from models.experimental import * +from utils.activations import replace_activations from utils.autoanchor import check_anchor_order from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args from utils.plots import feature_visualization @@ -33,6 +34,7 @@ class Detect(nn.Module): stride = None # strides computed during build onnx_dynamic = False # ONNX export parameter + export = True # onnx export def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer super().__init__() @@ -53,7 +55,7 @@ def forward(self, x): bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() - if not self.training: # inference + if not self.training and self.export: # inference if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) @@ -67,7 +69,7 @@ def forward(self, x): y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) - return x if self.training else (torch.cat(z, 1), x) + return x if self.training or not self.export else (torch.cat(z, 1), x) def _make_grid(self, nx=20, ny=20, i=0): d = self.anchors[i].device @@ -291,7 +293,15 @@ def parse_model(d, ch): # model_dict, input_channels(3) if i == 0: ch = [] ch.append(c2) - return nn.Sequential(*layers), sorted(save) + + model = nn.Sequential(*layers) + + # override all activations in model if provided in config + if 'act' in d: + LOGGER.info(f'overriding activations in model to {d["act"]}') + replace_activations(model, d["act"]) + + return model, sorted(save) if __name__ == '__main__': diff --git a/train.py b/train.py index 60be962d447f..a263d1e9c996 100644 --- a/train.py +++ b/train.py @@ -40,6 +40,7 @@ import val # for end-of-epoch mAP from models.experimental import attempt_load +from export import load_checkpoint, create_checkpoint from 
models.yolo import Model from utils.autoanchor import check_anchors from utils.autobatch import check_train_batch_size @@ -56,6 +57,7 @@ from utils.metrics import fitness from utils.plots import plot_evolve, plot_labels from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first +from utils.sparse import SparseMLWrapper LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -85,9 +87,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Save run settings if not evolve: with open(save_dir / 'hyp.yaml', 'w') as f: - yaml.safe_dump(hyp, f, sort_keys=False) + yaml.dump(hyp, f, sort_keys=False) with open(save_dir / 'opt.yaml', 'w') as f: - yaml.safe_dump(vars(opt), f, sort_keys=False) + yaml.dump(vars(opt), f, sort_keys=False) # Loggers data_dict = None @@ -105,6 +107,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Config plots = not evolve # create plots cuda = device.type != 'cpu' + half_precision = cuda init_seeds(1 + RANK) with torch_distributed_zero_first(LOCAL_RANK): data_dict = data_dict or check_dataset(data) # check if None @@ -115,20 +118,27 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset # Model - check_suffix(weights, '.pt') # check weights - pretrained = weights.endswith('.pt') + check_suffix(weights, ['.pt', '.pth']) # check weights + pretrained = weights.endswith('.pt') or weights.endswith('.pth') or weights.startswith('zoo:') if pretrained: - with torch_distributed_zero_first(LOCAL_RANK): - weights = attempt_download(weights) # download if not found locally - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak - model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create - exclude = ['anchor'] if (cfg or hyp.get('anchors')) 
and not resume else [] # exclude keys - csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 - csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect - model.load_state_dict(csd, strict=False) # load - LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report + model, extras = load_checkpoint( + type_ = 'train', + weights=weights, + device=device, + cfg=opt.cfg, + hyp=hyp, + nc=nc, + recipe=opt.recipe, + resume=opt.resume, + rank=LOCAL_RANK + ) + ckpt, state_dict, sparseml_wrapper = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'] + LOGGER.info(extras['report']) else: model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + sparseml_wrapper = SparseMLWrapper(model, opt.recipe) + sparseml_wrapper.initialize(start_epoch=0.0) + ckpt = None # Freeze freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze @@ -183,11 +193,22 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs) # EMA - ema = ModelEMA(model) if RANK in [-1, 0] else None + ema = ModelEMA(model, enabled=not opt.disable_ema) if RANK in [-1, 0] else None # Resume start_epoch, best_fitness = 0, 0.0 if pretrained: + # Epochs + start_epoch = ckpt['epoch'] + 1 + if opt.resume: + assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs) + if epochs < start_epoch: + LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' 
% + (weights, ckpt['epoch'], epochs)) + epochs += ckpt['epoch'] # finetune additional epochs + if sparseml_wrapper.qat_active(start_epoch): + ema.enabled = False + # Optimizer if ckpt['optimizer'] is not None: optimizer.load_state_dict(ckpt['optimizer']) @@ -198,15 +219,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) ema.updates = ckpt['updates'] - # Epochs - start_epoch = ckpt['epoch'] + 1 - if resume: - assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.' - if epochs < start_epoch: - LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.") - epochs += ckpt['epoch'] # finetune additional epochs - - del ckpt, csd + del ckpt # DP mode if cuda and RANK == -1 and torch.cuda.device_count() > 1: @@ -247,7 +260,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Anchors if not opt.noautoanchor: check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) - model.half().float() # pre-reduce anchor precision callbacks.run('on_pretrain_routine_end') @@ -273,15 +285,29 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary last_opt_step = -1 maps = np.zeros(nc) # mAP per class results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) - scheduler.last_epoch = start_epoch - 1 # do not move - scaler = amp.GradScaler(enabled=cuda) + if scheduler: + scheduler.last_epoch = start_epoch - 1 # do not mov + scaler = amp.GradScaler(enabled=half_precision) stopper = EarlyStopping(patience=opt.patience) compute_loss = ComputeLoss(model) # init loss class LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n' f"Logging results to {colorstr('bold', save_dir)}\n" f'Starting training for {epochs} epochs...') + + # SparseML Integration + if RANK in [-1, 0]: + sparseml_wrapper.initialize_loggers(loggers.logger, loggers.tb, 
loggers.wandb) + scaler = sparseml_wrapper.modify(scaler, optimizer, model, train_loader) + scheduler = sparseml_wrapper.check_lr_override(scheduler, RANK) + epochs = sparseml_wrapper.check_epoch_override(epochs, RANK) + for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ + if sparseml_wrapper.qat_active(epoch): + LOGGER.info('Disabling half precision and EMA, QAT scheduled to run') + half_precision = False + scaler._enabled = False + ema.enabled = False model.train() # Update image weights (optional, single-GPU only) @@ -313,7 +339,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 - x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) + if scheduler: + x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) if 'momentum' in x: x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']]) @@ -326,7 +353,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) # Forward - with amp.autocast(enabled=cuda): + with amp.autocast(enabled=half_precision): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: @@ -345,6 +372,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if ema: ema.update(model) last_opt_step = ni + elif hasattr(scaler, "emulated_step"): + # Call for SparseML integration since the number of steps per epoch can vary + # This keeps the number of steps per epoch equivalent to the number of batches per epoch + # Does not step the scaler or the optimizer + scaler.emulated_step() # Log if RANK in [-1, 0]: @@ -359,7 +391,8 @@ def 
train(hyp, # path/to/hyp.yaml or hyp dictionary # Scheduler lr = [x['lr'] for x in optimizer.param_groups] # for loggers - scheduler.step() + if scheduler: + scheduler.step() if RANK in [-1, 0]: # mAP @@ -376,25 +409,23 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary save_dir=save_dir, plots=False, callbacks=callbacks, - compute_loss=compute_loss) + compute_loss=compute_loss, + half=half_precision) # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] - if fi > best_fitness: + if fi > best_fitness or sparseml_wrapper.reset_best(epoch): best_fitness = fi log_vals = list(mloss) + list(results) + lr callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi) # Save model - if (not nosave) or (final_epoch and not evolve): # if save - ckpt = {'epoch': epoch, - 'best_fitness': best_fitness, - 'model': deepcopy(de_parallel(model)).half(), - 'ema': deepcopy(ema.ema).half(), - 'updates': ema.updates, - 'optimizer': optimizer.state_dict(), - 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None, - 'date': datetime.now().isoformat()} + if (not opt.nosave) or (final_epoch and not opt.evolve): # if save + ckpt_extras = {'nc': nc, + 'best_fitness': best_fitness, + 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None, + 'date': datetime.now().isoformat()} + ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras) # Save last, best and delete torch.save(ckpt, last) @@ -422,7 +453,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- if RANK in [-1, 0]: - LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') + LOGGER.info(f'\n{epochs - start_epoch + 1} epochs completed in 
{(time.time() - t0) / 3600:.3f} hours.') for f in last, best: if f.exists(): strip_optimizer(f) # strip optimizers @@ -431,7 +462,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary results, _, _ = val.run(data_dict, batch_size=batch_size // WORLD_SIZE * 2, imgsz=imgsz, - model=attempt_load(f, device).half(), + model=load_checkpoint(type_='ensemble', weights=best, device=device)[0], iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65 single_cls=single_cls, dataloader=val_loader, @@ -440,7 +471,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary verbose=True, plots=True, callbacks=callbacks, - compute_loss=compute_loss) # val best model with plots + compute_loss=compute_loss, # val best model with plots + half=half_precision) if is_coco: callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi) @@ -491,6 +523,9 @@ def parse_opt(known=False): parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option') parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval') parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use') + parser.add_argument('--recipe', type=str, default=None, help='Path to a sparsification recipe, ' + 'see https://github.com/neuralmagic/sparseml for more information') + parser.add_argument('--disable-ema', action='store_true', help='Disable EMA model updates (enabled by default)') opt = parser.parse_known_args()[0] if known else parser.parse_args() return opt @@ -508,7 +543,7 @@ def main(opt, callbacks=Callbacks()): ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f: - opt = argparse.Namespace(**yaml.safe_load(f)) # replace + opt 
= argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate LOGGER.info(f'Resuming training from {ckpt}') else: @@ -518,8 +553,8 @@ def main(opt, callbacks=Callbacks()): if opt.evolve: if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve opt.project = str(ROOT / 'runs/evolve') - opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume - opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) + opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume + opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run # DDP mode device = select_device(opt.device, batch_size=opt.batch_size) @@ -575,7 +610,7 @@ def main(opt, callbacks=Callbacks()): 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability) with open(opt.hyp, errors='ignore') as f: - hyp = yaml.safe_load(f) # load hyps dict + hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps dict if 'anchors' not in hyp: # anchors commented in hyp.yaml hyp['anchors'] = 3 opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch diff --git a/utils/activations.py b/utils/activations.py index a4ff789cf336..b119d915e54c 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -7,6 +7,23 @@ import torch.nn as nn import torch.nn.functional as F +def is_activation(mod, act_types=None): + if not act_types: + act_types = (nn.ELU, nn.Hardshrink, nn.Hardsigmoid, nn.Hardtanh, nn.Hardswish, nn.LeakyReLU, + nn.LogSigmoid, nn.PReLU, nn.ReLU, nn.ReLU6, nn.RReLU, nn.SELU, nn.CELU, nn.GELU, + nn.Sigmoid, nn.SiLU, nn.Softplus, nn.Softshrink, nn.Softsign, nn.Tanh, nn.Tanhshrink, + SiLU, Hardswish, Mish, MemoryEfficientMish, FReLU) + + return isinstance(mod, act_types) + + +def replace_activations(mod, act, act_types=None): + for name, child in 
mod.named_children(): + if is_activation(child, act_types): + child_act = act if not isinstance(act, str) else eval(act)() + setattr(mod, name, child_act) + else: + replace_activations(child, act, act_types) # SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- class SiLU(nn.Module): # export-friendly version of nn.SiLU() diff --git a/utils/downloads.py b/utils/downloads.py index d7b87cb2cadd..714ffb2a0452 100644 --- a/utils/downloads.py +++ b/utils/downloads.py @@ -42,6 +42,9 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads import *; attempt_download() # Attempt file download if does not exist + if not isinstance(file, (Path, str)) or str(file).startswith("zoo:"): + return + file = Path(str(file).strip().replace("'", '')) if not file.exists(): diff --git a/utils/general.py b/utils/general.py index b0c5e9d69ab7..aeea4f3792c6 100755 --- a/utils/general.py +++ b/utils/general.py @@ -803,9 +803,11 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys x[k] = None x['epoch'] = -1 - x['model'].half() # to FP16 - for p in x['model'].parameters(): - p.requires_grad = False + pickled = isinstance(x['model'], torch.nn.Module) + if pickled: + x['model'].half() # to FP16 + for p in x['model'].parameters(): + p.requires_grad = False torch.save(x, s or f) mb = os.path.getsize(s or f) / 1E6 # filesize LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") diff --git a/utils/loggers/wandb/wandb_utils.py b/utils/loggers/wandb/wandb_utils.py index 786e58a19972..a2c7102bce14 100644 --- a/utils/loggers/wandb/wandb_utils.py +++ b/utils/loggers/wandb/wandb_utils.py @@ -169,6 +169,10 @@ def __init__(self, opt, run_id=None, job_type='Training'): if opt.upload_dataset: if not 
opt.resume: self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt) + self.wandb_run.config.update({ + 'opt': vars(opt), + 'data_dict': self.wandb_artifact_data_dict + }, allow_val_change=True) if opt.resume: # resume from artifact diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 72f8a0fd1659..02698e656481 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -285,27 +285,47 @@ class ModelEMA: For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage """ - def __init__(self, model, decay=0.9999, tau=2000, updates=0): + def __init__(self, model, decay=0.9999, tau=2000, updates=0, enabled=True): # Create EMA - self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA + self._model = model + self._ema = deepcopy(de_parallel(model)).eval() # FP32 EMA # if next(model.parameters()).device.type != 'cpu': # self.ema.half() # FP16 EMA self.updates = updates # number of EMA updates self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) - for p in self.ema.parameters(): + self.enabled=enabled + for p in self._ema.parameters(): p.requires_grad_(False) + @property + def ema(self): + if not self.enabled: + return deepcopy(self._model.module if is_parallel(self._model) else self._model).eval() + return self._ema + + def state_dict(self, pickle=True): + ema = deepcopy(self.ema).float() + return { + 'ema': ema if pickle else ema.state_dict(), + 'updates': self.updates, + } + def update(self, model): + self._model = model + if not self.enabled: + return # Update EMA parameters with torch.no_grad(): + msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict self.updates += 1 d = self.decay(self.updates) - msd = de_parallel(model).state_dict() # model state_dict for k, v in self.ema.state_dict().items(): if v.dtype.is_floating_point: + mv = msd[k].detach() v *= d - v += (1 - d) * msd[k].detach() + v += (1. 
- d) * mv + v *= mv != 0 # preserve pruned parameters in model (equal to 0) def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): # Update EMA attributes diff --git a/val.py b/val.py index 2dd2aec679f9..a7503a50f247 100644 --- a/val.py +++ b/val.py @@ -35,6 +35,7 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative +from export import load_checkpoint from models.common import DetectMultiBackend from utils.callbacks import Callbacks from utils.datasets import create_dataloader @@ -135,7 +136,7 @@ def run(data, (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Load model - model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + model, extras = load_checkpoint(type_='val', weights=weights, device=device) # load FP32 model stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine imgsz = check_img_size(imgsz, s=stride) # check image size half = model.fp16 # FP16 supported on limited backends with CUDA From 1f7355281cc32b15f01849c97c89e92f24b04feb Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 25 Mar 2022 13:57:29 -0400 Subject: [PATCH 02/16] Add SparseML dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 96fc9d1a1f32..818b9d30bdfd 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,4 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation +sparseml[torch,torchvision]>=0.11 # Pruning and Quantization From 4db2a15cf62ef76756cbb3611a93cfef5f450d3a Mon Sep 17 00:00:00 2001 From: Konstantin Date: Sat, 26 Mar 2022 10:54:09 -0400 Subject: [PATCH 03/16] Update: add missing files --- models_v5.0/yolov5l.yaml | 48 +++++++++++++ models_v5.0/yolov5m.yaml | 48 +++++++++++++ models_v5.0/yolov5s.yaml | 48 +++++++++++++ models_v5.0/yolov5x.yaml | 48 +++++++++++++ utils/sparse.py | 141
+++++++++++++++++++++++++++++++++++++++ 5 files changed, 333 insertions(+) create mode 100644 models_v5.0/yolov5l.yaml create mode 100644 models_v5.0/yolov5m.yaml create mode 100644 models_v5.0/yolov5s.yaml create mode 100644 models_v5.0/yolov5x.yaml create mode 100644 utils/sparse.py diff --git a/models_v5.0/yolov5l.yaml b/models_v5.0/yolov5l.yaml new file mode 100644 index 000000000000..71ebf86e5791 --- /dev/null +++ b/models_v5.0/yolov5l.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models_v5.0/yolov5m.yaml b/models_v5.0/yolov5m.yaml new file mode 100644 index 000000000000..3c749c916246 --- /dev/null +++ b/models_v5.0/yolov5m.yaml @@ -0,0 +1,48 @@ +# parameters 
+nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models_v5.0/yolov5s.yaml b/models_v5.0/yolov5s.yaml new file mode 100644 index 000000000000..aca669d60d8b --- /dev/null +++ b/models_v5.0/yolov5s.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + 
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models_v5.0/yolov5x.yaml b/models_v5.0/yolov5x.yaml new file mode 100644 index 000000000000..d3babdf7baf0 --- /dev/null +++ b/models_v5.0/yolov5x.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.33 # model depth multiple +width_multiple: 1.25 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone 
P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/utils/sparse.py b/utils/sparse.py new file mode 100644 index 000000000000..73a3e29e5259 --- /dev/null +++ b/utils/sparse.py @@ -0,0 +1,141 @@ +import math + +from sparsezoo import Zoo +from sparseml.pytorch.optim import ScheduledModifierManager +from sparseml.pytorch.utils import SparsificationGroupLogger + +from utils.torch_utils import is_parallel + + +def _get_model_framework_file(model, path): + transfer_request = 'recipe_type=transfer' in path + checkpoint_available = any([file.checkpoint for file in model.framework_files]) + final_available = any([not file.checkpoint for file in model.framework_files]) + + if transfer_request and checkpoint_available: + # checkpoints are saved for transfer learning use cases, + # return checkpoint if avaiable and requested + return [file for file in model.framework_files if file.checkpoint][0] + elif final_available: + # default to returning final state, if available + return [file for file in model.framework_files if not file.checkpoint][0] + + raise ValueError(f"Could not find a valid framework file for {path}") + + +def check_download_sparsezoo_weights(path): + if isinstance(path, str): + if path.startswith("zoo:"): + # load model from the SparseZoo and override the path with the new download + model = Zoo.load_model_from_stub(path) + file = _get_model_framework_file(model, path) + path = file.downloaded_path() + + return path + + if isinstance(path, list): + return [check_download_sparsezoo_weights(p) for p in path] + + return path + + +class SparseMLWrapper(object): + def __init__(self, model, recipe): + self.enabled = bool(recipe) + self.model 
= model.module if is_parallel(model) else model + self.recipe = recipe + self.manager = ScheduledModifierManager.from_yaml(recipe) if self.enabled else None + self.logger = None + + def state_dict(self): + return { + 'recipe': str(self.manager) if self.enabled else None, + } + + def apply(self): + if not self.enabled: + return + + self.manager.apply(self.model) + + def initialize(self, start_epoch): + if not self.enabled: + return + + self.manager.initialize(self.model, start_epoch) + + def initialize_loggers(self, logger, tb_writer, wandb_logger): + self.logger = logger + + if not self.enabled: + return + + def _logging_lambda(tag, value, values, step, wall_time, level): + if not wandb_logger or not wandb_logger.wandb: + return + + if value is not None: + wandb_logger.log({tag: value}) + + if values: + wandb_logger.log(values) + + self.manager.initialize_loggers([ + SparsificationGroupLogger( + lambda_func=_logging_lambda, + tensorboard=tb_writer, + ) + ]) + + if wandb_logger and wandb_logger.wandb: + artifact = wandb_logger.wandb.Artifact('recipe', type='recipe') + with artifact.new_file('recipe.yaml') as file: + file.write(str(self.manager)) + wandb_logger.wandb.log_artifact(artifact) + + def modify(self, scaler, optimizer, model, dataloader): + if not self.enabled: + return scaler + + return self.manager.modify(model, optimizer, steps_per_epoch=len(dataloader), wrap_optim=scaler) + + def check_lr_override(self, scheduler, rank): + # Override lr scheduler if recipe makes any LR updates + if self.enabled and self.manager.learning_rate_modifiers: + if rank in [0,-1]: + self.logger.info('Disabling LR scheduler, managing LR using SparseML recipe') + scheduler = None + + return scheduler + + def check_epoch_override(self, epochs, rank): + # Override num epochs if recipe explicitly modifies epoch range + if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs: + if rank in [0,-1]: + self.logger.info(f'Overriding number of epochs from SparseML 
manager to {epochs}') + epochs = self.manager.max_epochs or epochs # override num_epochs + + return epochs + + def qat_active(self, epoch): + if not self.enabled or not self.manager.quantization_modifiers: + return False + + qat_start = min([mod.start_epoch for mod in self.manager.quantization_modifiers]) + + return qat_start < epoch + 1 + + def reset_best(self, epoch): + if not self.enabled: + return False + + # if pruning is active or quantization just started, need to reset best checkpoint + # this is in case the pruned and/or quantized model do not fully recover + pruning_start = math.floor(max([mod.start_epoch for mod in self.manager.pruning_modifiers])) \ + if self.manager.pruning_modifiers else -1 + pruning_end = math.ceil(max([mod.end_epoch for mod in self.manager.pruning_modifiers])) \ + if self.manager.pruning_modifiers else -1 + qat_start = math.floor(max([mod.start_epoch for mod in self.manager.quantization_modifiers])) \ + if self.manager.quantization_modifiers else -1 + + return (pruning_start <= epoch <= pruning_end) or epoch == qat_start \ No newline at end of file From 70fb4cd287950f49b343a6849eed43fe7668e1e4 Mon Sep 17 00:00:00 2001 From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com> Date: Wed, 30 Mar 2022 15:45:17 +0100 Subject: [PATCH 04/16] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 818b9d30bdfd..cf72a9ea4165 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,4 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation -sparseml[torch,torchvision]>=0.11 # Pruning and Quantization +sparseml[torch,torchvision]>=0.12 # Pruning and Quantization From bf225ec63be543b30e70e5fb344b3101d368a4fa Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 30 Mar 2022 13:36:17 -0400 Subject: [PATCH 05/16] Update: sparseml-nightly support --- requirements.txt | 2 +- utils/general.py | 4 ++++ 2 files 
changed, 5 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 818b9d30bdfd..cf72a9ea4165 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,4 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation -sparseml[torch,torchvision]>=0.11 # Pruning and Quantization +sparseml[torch,torchvision]>=0.12 # Pruning and Quantization diff --git a/utils/general.py b/utils/general.py index aeea4f3792c6..dcdbf95ddca1 100755 --- a/utils/general.py +++ b/utils/general.py @@ -319,6 +319,10 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta n = 0 # number of packages updates for r in requirements: + if r.startswith("sparseml"): + version = r.split("sparseml")[1] + if pkg.working_set.find(pkg.Requirement("sparseml-nightly" + version)): + continue try: pkg.require(r) except Exception: # DistributionNotFound or VersionConflict if requirements not met From 6d3667abd8e8c2c2edacf74129a9f87b169c38d6 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 30 Mar 2022 14:41:41 -0400 Subject: [PATCH 06/16] Update: remove model versioning --- export.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/export.py b/export.py index 078b3c6940f0..ab609664a2b2 100644 --- a/export.py +++ b/export.py @@ -440,14 +440,11 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs): if not pickle: ckpt_model = ckpt_model.state_dict() - version = 6 if isinstance([module for module in model.model.modules()][1], Conv) else 5 - return {'epoch': epoch, 'model': ckpt_model, 'optimizer': optimizer.state_dict(), 'yaml': yaml, 'hyp': model.hyp, - 'version': version, **ema.state_dict(pickle), **sparseml_wrapper.state_dict(), **kwargs} From 28579c88fc7649844923f12caba0c990047c959a Mon Sep 17 00:00:00 2001 From: Konstantin Date: Tue, 5 Apr 2022 12:46:33 -0400 Subject: [PATCH 07/16] Partial update for multi-stage recipes --- export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/export.py b/export.py index ab609664a2b2..0c5bf9c0c1cf 100644 --- a/export.py +++ b/export.py @@ -498,7 +498,7 @@ def load_checkpoint( p.requires_grad = True # load sparseml recipe for applying pruning and quantization - recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None) + recipe = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe) exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume loaded = False From e5999d577cd4172bea7ac78cca00020d467c74f5 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 6 Apr 2022 13:12:16 -0400 Subject: [PATCH 08/16] Update: multi-stage recipe support --- export.py | 17 +++-------------- utils/sparse.py | 11 +++++++---- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/export.py b/export.py index 0c5bf9c0c1cf..5af8553f04b7 100644 --- a/export.py +++ b/export.py @@ -498,24 +498,13 @@ def load_checkpoint( p.requires_grad = True # load sparseml recipe for applying pruning and quantization - recipe = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe - sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe) + recipe_new = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe + recipe_base = None if resume else ckpt['recipe'] + sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe_new, recipe_base) exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume loaded = False if not train_type: - # update param names for yolov5x5 models (model.x -> model.model.x) - ''' - if ('version' not in ckpt or ckpt['version'] < 6) and sparseml_wrapper.manager is not None: - for modifier in sparseml_wrapper.manager.pruning_modifiers: - updated_params = [] - for param in modifier.params: - updated_params.append( - "model." 
+ param if (param.startswith('model.') and - not param.startswith('model.model.')) else param - ) - modifier.params = updated_params - ''' # apply the recipe to create the final state of the model when not training sparseml_wrapper.apply() else: diff --git a/utils/sparse.py b/utils/sparse.py index 73a3e29e5259..95652045f21f 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -40,11 +40,14 @@ def check_download_sparsezoo_weights(path): class SparseMLWrapper(object): - def __init__(self, model, recipe): - self.enabled = bool(recipe) + def __init__(self, model, recipe_new, recipe_base = None): + self.enabled = bool(recipe_new) self.model = model.module if is_parallel(model) else model - self.recipe = recipe - self.manager = ScheduledModifierManager.from_yaml(recipe) if self.enabled else None + if self.enabled: + self.manager = (ScheduledModifierManager.compose_staged(recipe_base, recipe_new) + if recipe_base else ScheduledModifierManager.from_yaml(recipe_new)) + else: + self.manager = None self.logger = None def state_dict(self): From 3218a78be333cde71364d639fae9c2bdbe35e1d1 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 6 Apr 2022 16:04:15 -0400 Subject: [PATCH 09/16] Update: remove sparseml dep --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cf72a9ea4165..96fc9d1a1f32 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,3 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation -sparseml[torch,torchvision]>=0.12 # Pruning and Quantization From bbbcf6b09df95eec33015bec7e2bef8f89a9ea5b Mon Sep 17 00:00:00 2001 From: Konstantin Date: Wed, 6 Apr 2022 18:08:09 -0400 Subject: [PATCH 10/16] Fix: multi-stage recipe handling --- export.py | 17 +++++++++++++---- train.py | 5 +++-- utils/sparse.py | 13 ++++++------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/export.py b/export.py index 5af8553f04b7..394779b77b22 100644 --- a/export.py +++
b/export.py @@ -433,8 +433,10 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') -def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs): +def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, additional_recipe, **kwargs): pickle = not sparseml_wrapper.qat_active(epoch) # qat does not support pickled exports + if additional_recipe is not None: + sparseml_wrapper.add_stage(additional_recipe) ckpt_model = deepcopy(model.module if is_parallel(model) else model).float() yaml = ckpt_model.yaml if not pickle: @@ -498,9 +500,13 @@ def load_checkpoint( p.requires_grad = True # load sparseml recipe for applying pruning and quantization - recipe_new = (ckpt['recipe'] if ('recipe' in ckpt) else None) if resume else recipe - recipe_base = None if resume else ckpt['recipe'] - sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe_new, recipe_base) + additional_recipe = None + if resume: + recipe = ckpt['recipe'] if ('recipe' in ckpt) else None + elif ckpt['recipe'] or recipe: + recipe, additional_recipe = (ckpt['recipe'], recipe) if (ckpt['recipe'] and recipe) else ((ckpt['recipe'] or recipe), None) + + sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe) exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume loaded = False @@ -513,6 +519,8 @@ def load_checkpoint( if not quantized_state_dict: state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors) loaded = True + if not resume: + start_epoch = sparseml_wrapper.manager.max_epochs + 1 sparseml_wrapper.initialize(start_epoch) if not loaded: @@ -527,6 +535,7 @@ def load_checkpoint( 'start_epoch': start_epoch, 'sparseml_wrapper': sparseml_wrapper, 'report': report, + 'additional_recipe': additional_recipe } diff --git a/train.py b/train.py index a263d1e9c996..665d856717e5 100644 --- 
a/train.py +++ b/train.py @@ -132,7 +132,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary resume=opt.resume, rank=LOCAL_RANK ) - ckpt, state_dict, sparseml_wrapper = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'] + ckpt, state_dict, sparseml_wrapper, start_epoch = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'], extras['start_epoch'] LOGGER.info(extras['report']) else: model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create @@ -424,7 +424,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ckpt_extras = {'nc': nc, 'best_fitness': best_fitness, 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None, - 'date': datetime.now().isoformat()} + 'date': datetime.now().isoformat(), + 'additional_recipe': extras["additional_recipe"]} ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras) # Save last, best and delete diff --git a/utils/sparse.py b/utils/sparse.py index 95652045f21f..9530283e7370 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -40,14 +40,10 @@ def check_download_sparsezoo_weights(path): class SparseMLWrapper(object): - def __init__(self, model, recipe_new, recipe_base = None): - self.enabled = bool(recipe_new) + def __init__(self, model, recipe): + self.enabled = bool(recipe) self.model = model.module if is_parallel(model) else model - if self.enabled: - self.manager = (ScheduledModifierManager.compose_staged(recipe_base, recipe_new) - if recipe_base else ScheduledModifierManager.from_yaml(recipe_new)) - else: - self.manager = None + self.manager = ScheduledModifierManager.from_yaml(recipe) if self.enabled else None self.logger = None def state_dict(self): @@ -102,6 +98,9 @@ def modify(self, scaler, optimizer, model, dataloader): return self.manager.modify(model, optimizer, steps_per_epoch=len(dataloader), wrap_optim=scaler) + def add_stage(self, additional_recipe): + self.manager = ScheduledModifierManager.compose_staged(self.manager, 
additional_recipe) + def check_lr_override(self, scheduler, rank): # Override lr scheduler if recipe makes any LR updates if self.enabled and self.manager.learning_rate_modifiers: From 140ee49cc3138b0521fa771e2a5ab6eb58c9980f Mon Sep 17 00:00:00 2001 From: Konstantin Date: Thu, 7 Apr 2022 16:34:39 -0400 Subject: [PATCH 11/16] Fix: multi stage support --- export.py | 21 +++++++-------------- requirements.txt | 1 + train.py | 15 ++++++--------- utils/loggers/__init__.py | 5 ++++- utils/sparse.py | 30 +++++++++++++++++++----------- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/export.py b/export.py index 394779b77b22..ce377c758c04 100644 --- a/export.py +++ b/export.py @@ -433,10 +433,8 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') -def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, additional_recipe, **kwargs): +def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs): pickle = not sparseml_wrapper.qat_active(epoch) # qat does not support pickled exports - if additional_recipe is not None: - sparseml_wrapper.add_stage(additional_recipe) ckpt_model = deepcopy(model.module if is_parallel(model) else model).float() yaml = ckpt_model.yaml if not pickle: @@ -500,27 +498,23 @@ def load_checkpoint( p.requires_grad = True # load sparseml recipe for applying pruning and quantization - additional_recipe = None + checkpoint_recipe = None if resume: - recipe = ckpt['recipe'] if ('recipe' in ckpt) else None + train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None elif ckpt['recipe'] or recipe: - recipe, additional_recipe = (ckpt['recipe'], recipe) if (ckpt['recipe'] and recipe) else ((ckpt['recipe'] or recipe), None) + train_recipe, checkpoint_recipe = recipe, ckpt['recipe'] - sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, recipe) + sparseml_wrapper = SparseMLWrapper(model.model if val_type 
else model, checkpoint_recipe, train_recipe) exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume loaded = False - if not train_type: - # apply the recipe to create the final state of the model when not training - sparseml_wrapper.apply() - else: + sparseml_wrapper.apply(ckpt['epoch'] if 'epoch' in ckpt else 0) + if train_type: # intialize the recipe for training and restore the weights before if no quantized weights quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()]) if not quantized_state_dict: state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors) loaded = True - if not resume: - start_epoch = sparseml_wrapper.manager.max_epochs + 1 sparseml_wrapper.initialize(start_epoch) if not loaded: @@ -535,7 +529,6 @@ def load_checkpoint( 'start_epoch': start_epoch, 'sparseml_wrapper': sparseml_wrapper, 'report': report, - 'additional_recipe': additional_recipe } diff --git a/requirements.txt b/requirements.txt index 96fc9d1a1f32..ab1c44f64132 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,4 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation +sparseml[torch, torchvision] >= 0.12 \ No newline at end of file diff --git a/train.py b/train.py index 665d856717e5..e8ed8984d710 100644 --- a/train.py +++ b/train.py @@ -136,8 +136,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary LOGGER.info(extras['report']) else: model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create - sparseml_wrapper = SparseMLWrapper(model, opt.recipe) - sparseml_wrapper.initialize(start_epoch=0.0) + sparseml_wrapper = SparseMLWrapper(model, None, opt.recipe) + sparseml_wrapper.initialize(start_epoch=0) ckpt = None # Freeze @@ -196,16 +196,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ema = ModelEMA(model, enabled=not opt.disable_ema) if RANK in [-1, 0] else None # Resume - start_epoch, best_fitness = 0, 0.0 + 
start_epoch, best_fitness = sparseml_wrapper.start_epoch, 0.0 if pretrained: - # Epochs - start_epoch = ckpt['epoch'] + 1 if opt.resume: assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs) if epochs < start_epoch: LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' % - (weights, ckpt['epoch'], epochs)) - epochs += ckpt['epoch'] # finetune additional epochs + (weights, start_epoch-1, epochs)) + epochs += start_epoch # finetune additional epochs if sparseml_wrapper.qat_active(start_epoch): ema.enabled = False @@ -424,8 +422,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ckpt_extras = {'nc': nc, 'best_fitness': best_fitness, 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None, - 'date': datetime.now().isoformat(), - 'additional_recipe': extras["additional_recipe"]} + 'date': datetime.now().isoformat()} ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras) # Save last, best and delete diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index ff6722ecd48a..3b2230c02a14 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -87,7 +87,10 @@ def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn): if not sync_bn: # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754 with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress jit trace warning - self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) + try: + self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) + except Exception: + warnings.warn("Couldn't create quantized graph for Tensorboard") if ni < 3: f = self.save_dir / f'train_batch{ni}.jpg' # filename Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() diff --git a/utils/sparse.py b/utils/sparse.py index 9530283e7370..d624554a9ca6 100644 
--- a/utils/sparse.py +++ b/utils/sparse.py @@ -40,28 +40,39 @@ def check_download_sparsezoo_weights(path): class SparseMLWrapper(object): - def __init__(self, model, recipe): - self.enabled = bool(recipe) + def __init__(self, model, checkpoint_recipe, train_recipe): + self.enabled = bool(checkpoint_recipe or train_recipe) self.model = model.module if is_parallel(model) else model - self.manager = ScheduledModifierManager.from_yaml(recipe) if self.enabled else None + self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None + self.manager = ScheduledModifierManager.from_yaml(train_recipe) if train_recipe else None self.logger = None + self.start_epoch = None def state_dict(self): + if self.checkpoint_manager: + manager = ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager) + else: + manager = self.manager return { - 'recipe': str(self.manager) if self.enabled else None, + 'recipe': str(manager) if self.enabled else None, } - def apply(self): + def apply(self, epoch): if not self.enabled: return - self.manager.apply(self.model) + if epoch < 0: + epoch = math.inf + + if self.checkpoint_manager: + self.checkpoint_manager.apply_structure(self.model, epoch) def initialize(self, start_epoch): if not self.enabled: return - + self.manager.initialize(self.model, start_epoch) + self.start_epoch = start_epoch def initialize_loggers(self, logger, tb_writer, wandb_logger): self.logger = logger @@ -98,9 +109,6 @@ def modify(self, scaler, optimizer, model, dataloader): return self.manager.modify(model, optimizer, steps_per_epoch=len(dataloader), wrap_optim=scaler) - def add_stage(self, additional_recipe): - self.manager = ScheduledModifierManager.compose_staged(self.manager, additional_recipe) - def check_lr_override(self, scheduler, rank): # Override lr scheduler if recipe makes any LR updates if self.enabled and self.manager.learning_rate_modifiers: @@ -115,7 +123,7 @@ def check_epoch_override(self, 
epochs, rank): if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs: if rank in [0,-1]: self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}') - epochs = self.manager.max_epochs or epochs # override num_epochs + epochs = self.manager.max_epochs + self.start_epoch or epochs # override num_epochs return epochs From 912040caa773d70f247fc0c2ff9e385145f94258 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 8 Apr 2022 06:26:44 -0400 Subject: [PATCH 12/16] Fix: non-recipe runs --- export.py | 3 +-- train.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/export.py b/export.py index ce377c758c04..bdf579768908 100644 --- a/export.py +++ b/export.py @@ -498,7 +498,7 @@ def load_checkpoint( p.requires_grad = True # load sparseml recipe for applying pruning and quantization - checkpoint_recipe = None + checkpoint_recipe = train_recipe = None if resume: train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None elif ckpt['recipe'] or recipe: @@ -526,7 +526,6 @@ def load_checkpoint( return model, { 'ckpt': ckpt, 'state_dict': state_dict, - 'start_epoch': start_epoch, 'sparseml_wrapper': sparseml_wrapper, 'report': report, } diff --git a/train.py b/train.py index e8ed8984d710..738155ad1f77 100644 --- a/train.py +++ b/train.py @@ -132,7 +132,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary resume=opt.resume, rank=LOCAL_RANK ) - ckpt, state_dict, sparseml_wrapper, start_epoch = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'], extras['start_epoch'] + ckpt, state_dict, sparseml_wrapper = extras['ckpt'], extras['state_dict'], extras['sparseml_wrapper'] LOGGER.info(extras['report']) else: model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create @@ -196,7 +196,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary ema = ModelEMA(model, enabled=not opt.disable_ema) if RANK in [-1, 0] else None # Resume - start_epoch, best_fitness = 
sparseml_wrapper.start_epoch, 0.0 + start_epoch = sparseml_wrapper.start_epoch or 0 + best_fitness = 0.0 if pretrained: if opt.resume: assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs) From a8dfa0f73054ba3250345182c571972947e84241 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 8 Apr 2022 07:17:02 -0400 Subject: [PATCH 13/16] Add: legacy hyperparam files --- data/hyps/hyp.finetune.yaml | 38 +++++++++++++++++++++++++++++++++++++ data/hyps/hyp.scratch.yaml | 33 ++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 data/hyps/hyp.finetune.yaml create mode 100644 data/hyps/hyp.scratch.yaml diff --git a/data/hyps/hyp.finetune.yaml b/data/hyps/hyp.finetune.yaml new file mode 100644 index 000000000000..1b84cff95c2c --- /dev/null +++ b/data/hyps/hyp.finetune.yaml @@ -0,0 +1,38 @@ +# Hyperparameters for VOC finetuning +# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50 +# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials + + +# Hyperparameter Evolution Results +# Generations: 306 +# P R mAP.5 mAP.5:.95 box obj cls +# Metrics: 0.6 0.936 0.896 0.684 0.0115 0.00805 0.00146 + +lr0: 0.0032 +lrf: 0.12 +momentum: 0.843 +weight_decay: 0.00036 +warmup_epochs: 2.0 +warmup_momentum: 0.5 +warmup_bias_lr: 0.05 +box: 0.0296 +cls: 0.243 +cls_pw: 0.631 +obj: 0.301 +obj_pw: 0.911 +iou_t: 0.2 +anchor_t: 2.91 +# anchors: 3.63 +fl_gamma: 0.0 +hsv_h: 0.0138 +hsv_s: 0.664 +hsv_v: 0.464 +degrees: 0.373 +translate: 0.245 +scale: 0.898 +shear: 0.602 +perspective: 0.0 +flipud: 0.00856 +fliplr: 0.5 +mosaic: 1.0 +mixup: 0.243 diff --git a/data/hyps/hyp.scratch.yaml b/data/hyps/hyp.scratch.yaml new file mode 100644 index 000000000000..44f26b6658ae --- /dev/null +++ b/data/hyps/hyp.scratch.yaml @@ -0,0 +1,33 @@ +# Hyperparameters for COCO training from scratch +# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml 
--img 640 --epochs 300 +# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials + + +lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.0005 # optimizer weight decay 5e-4 +warmup_epochs: 3.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.1 # warmup initial bias lr +box: 0.05 # box loss gain +cls: 0.5 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 1.0 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.20 # IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +# anchors: 3 # anchors per output layer (0 to ignore) +fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.015 # image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # image HSV-Value augmentation (fraction) +degrees: 0.0 # image rotation (+/- deg) +translate: 0.1 # image translation (+/- fraction) +scale: 0.5 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 1.0 # image mosaic (probability) +mixup: 0.0 # image mixup (probability) From 7bdf2e66fda64556cd9b91bb8bea42fc21346635 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 8 Apr 2022 07:43:05 -0400 Subject: [PATCH 14/16] Fix: add copy-paste to hyps --- data/hyps/hyp.finetune.yaml | 1 + data/hyps/hyp.scratch.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/data/hyps/hyp.finetune.yaml b/data/hyps/hyp.finetune.yaml index 1b84cff95c2c..3aa1923f78a6 100644 --- a/data/hyps/hyp.finetune.yaml +++ b/data/hyps/hyp.finetune.yaml @@ -36,3 +36,4 @@ flipud: 0.00856 fliplr: 0.5 mosaic: 1.0 mixup: 0.243 +copy_paste: 0.0 diff --git 
a/data/hyps/hyp.scratch.yaml b/data/hyps/hyp.scratch.yaml index 44f26b6658ae..e10b9893dd50 100644 --- a/data/hyps/hyp.scratch.yaml +++ b/data/hyps/hyp.scratch.yaml @@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability) fliplr: 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) +copy_paste: 0.0 From 20f6f91ed61e8bc37c30e86229a3677b8fa5f2fb Mon Sep 17 00:00:00 2001 From: Konstantin Date: Fri, 8 Apr 2022 09:39:54 -0400 Subject: [PATCH 15/16] Fix: nit --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ab1c44f64132..36f39017d6af 100755 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,4 @@ seaborn>=0.11.0 # pycocotools>=2.0 # COCO mAP # roboflow thop # FLOPs computation -sparseml[torch, torchvision] >= 0.12 \ No newline at end of file +sparseml[torch,torchvision] >= 0.12 \ No newline at end of file From 5eadf3a40dbbf413caff3fa492dc7a31198e7d0e Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 8 Apr 2022 14:47:49 -0400 Subject: [PATCH 16/16] apply structure fixes --- export.py | 2 +- utils/sparse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index bdf579768908..f489aaa28f07 100644 --- a/export.py +++ b/export.py @@ -508,7 +508,7 @@ def load_checkpoint( exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume loaded = False - sparseml_wrapper.apply(ckpt['epoch'] if 'epoch' in ckpt else 0) + sparseml_wrapper.apply_checkpoint_structure(float("inf")) if train_type: # intialize the recipe for training and restore the weights before if no quantized weights quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()]) diff --git a/utils/sparse.py b/utils/sparse.py index d624554a9ca6..59b4640756f2 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -57,7 +57,7 @@ def state_dict(self): 'recipe': str(manager) if 
self.enabled else None, } - def apply(self, epoch): + def apply_checkpoint_structure(self, epoch): if not self.enabled: return