diff --git a/detect.py b/detect.py index ff8e32acbaed..70c52dc5214b 100644 --- a/detect.py +++ b/detect.py @@ -89,7 +89,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval() elif onnx: if dnn: - # check_requirements(('opencv-python>=4.5.4',)) + check_requirements(('opencv-python>=4.5.4',)) net = cv2.dnn.readNetFromONNX(w) else: check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime')) diff --git a/train.py b/train.py index da7346be77ab..d83f3cd1863c 100644 --- a/train.py +++ b/train.py @@ -36,6 +36,7 @@ from models.experimental import attempt_load from models.yolo import Model from utils.autoanchor import check_anchors +from utils.autobatch import check_train_batch_size from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \ @@ -131,6 +132,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary print(f'freezing {k}') v.requires_grad = False + # Image size + gs = max(int(model.stride.max()), 32) # grid size (max stride) + imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple + + # Batch size + if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size + batch_size = check_train_batch_size(model, imgsz) + # Optimizer nbs = 64 # nominal batch size accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing @@ -190,11 +199,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary del ckpt, csd - # Image sizes - gs = max(int(model.stride.max()), 32) # grid size (max stride) - nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) - imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple - # DP mode if cuda and RANK == -1 and torch.cuda.device_count() > 1: logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n' @@ -242,6 +246,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) # Model parameters + nl = model.model[-1].nl # number of detection layers (to scale hyps) hyp['box'] *= 3. / nl # scale to layers hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers @@ -440,7 +445,7 @@ def parse_opt(known=False): parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path') parser.add_argument('--epochs', type=int, default=300) - parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') + parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') diff --git a/tutorial.ipynb b/tutorial.ipynb index 421ddbeaa15f..47c44251b5ab 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -505,7 +505,7 @@ "id": "eyTZYGgRjnMc" }, "source": [ - "## COCO val2017\n", + "## COCO val\n", "Download [COCO val 2017](https://github.com/ultralytics/yolov5/blob/74b34872fdf41941cddcf243951cdb090fbac17b/data/coco.yaml#L14) dataset (1GB - 5000 images), and test model accuracy." ] }, @@ -533,8 +533,8 @@ "outputId": "7e6f5c96-c819-43e1-cd03-d3b9878cf8de" }, "source": [ - "# Download COCO val2017\n", - "torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017val.zip', 'tmp.zip')\n", + "# Download COCO val\n", + "torch.hub.download_url_to_file('https://ultralytics.com/assets/coco2017val.zip', 'tmp.zip')\n", "!unzip -q tmp.zip -d ../datasets && rm tmp.zip" ], "execution_count": null, @@ -567,7 +567,7 @@ "outputId": "3dd0e2fc-aecf-4108-91b1-6392da1863cb" }, "source": [ - "# Run YOLOv5x on COCO val2017\n", + "# Run YOLOv5x on COCO val\n", "!python val.py --weights yolov5x.pt --data coco.yaml --img 640 --iou 0.65 --half" ], "execution_count": null, @@ -627,7 +627,7 @@ "id": "rc_KbFk0juX2" }, "source": [ - "## COCO test-dev2017\n", + "## COCO test\n", "Download [COCO test2017](https://github.com/ultralytics/yolov5/blob/74b34872fdf41941cddcf243951cdb090fbac17b/data/coco.yaml#L15) dataset (7GB - 40,000 images), to test model accuracy on test-dev set (**20,000 images, no labels**). Results are saved to a `*.json` file which should be **zipped** and submitted to the evaluation server at https://competitions.codalab.org/competitions/20794." ] }, @@ -638,10 +638,9 @@ }, "source": [ "# Download COCO test-dev2017\n", - "torch.hub.download_url_to_file('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels.zip', 'tmp.zip')\n", - "!unzip -q tmp.zip -d ../ && rm tmp.zip # unzip labels\n", - "!f=\"test2017.zip\" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f && rm $f # 7GB, 41k images\n", - "%mv ./test2017 ../coco/images # move to /coco" + "torch.hub.download_url_to_file('https://ultralytics.com/assets/coco2017labels.zip', 'tmp.zip')\n", + "!unzip -q tmp.zip -d ../datasets && rm tmp.zip\n", + "!f=\"test2017.zip\" && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d ../datasets/coco/images" ], "execution_count": null, "outputs": [] @@ -652,8 +651,8 @@ "id": "29GJXAP_lPrt" }, "source": [ - "# Run YOLOv5s on COCO test-dev2017 using --task test\n", - "!python val.py --weights yolov5s.pt --data coco.yaml --task test" + "# Run YOLOv5x on COCO test\n", + "!python val.py --weights yolov5x.pt --data coco.yaml --img 640 --iou 0.65 --half --task test" ], "execution_count": null, "outputs": [] diff --git a/utils/autobatch.py b/utils/autobatch.py new file mode 100644 index 000000000000..cf65502d5608 --- /dev/null +++ b/utils/autobatch.py @@ -0,0 +1,58 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Auto-batch utils +""" + +from copy import deepcopy + +import numpy as np +import torch +from torch.cuda import amp + +from utils.general import colorstr +from utils.torch_utils import profile + + +def check_train_batch_size(model, imgsz=640): + # Check YOLOv5 training batch size + with amp.autocast(): + return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size + + +def autobatch(model, imgsz=640, fraction=0.9, batch_size=16): + # Automatically estimate best batch size to use `fraction` of available CUDA memory + # Usage: + # import torch + # from utils.autobatch import autobatch + # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False) + # print(autobatch(model)) + + prefix = colorstr('autobatch: ') + print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}') + device = next(model.parameters()).device # get model device + if device.type == 'cpu': + print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') + return batch_size + + d = str(device).upper() # 'CUDA:0' + t = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 # (GB) + r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GB) + a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GB) + f = t - (r + a) # free inside reserved + print(f'{prefix}{d} {t:.3g}G total, {r:.3g}G reserved, {a:.3g}G allocated, {f:.3g}G free') + + batch_sizes = [1, 2, 4, 8, 16] + try: + img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes] + y = profile(img, model, n=3, device=device) + except Exception as e: + print(f'{prefix}{e}') + + y = [x[2] for x in y if x] # memory [2] + batch_sizes = batch_sizes[:len(y)] + p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit + b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) + print(f'{prefix}batch-size {b} estimated to utilize {d} {t * fraction:.3g}G/{t:.3g}G ({fraction * 100:.0f}%)') + return b + +# autobatch(torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 352ecf572c9f..6f52f9a3728d 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -126,7 +126,7 @@ def profile(input, ops, n=10, device=None): _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward() t[2] = time_sync() except Exception as e: # no backward method - print(e) + # print(e) # for debug t[2] = float('nan') tf += (t[1] - t[0]) * 1000 / n # ms per op forward tb += (t[2] - t[1]) * 1000 / n # ms per op backward @@ -299,7 +299,10 @@ def __call__(self, epoch, fitness): self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch stop = delta >= self.patience # stop training if patience exceeded if stop: - LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.') + LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. ' + f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n' + f'To update EarlyStopping(patience={self.patience}) pass a new patience value, ' + f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.') return stop