Skip to content

Commit

Permalink
Add autoShape() speed profiling (ultralytics#2459)
Browse files Browse the repository at this point in the history
* Add autoShape() speed profiling

* Update common.py

* Create README.md

* Update hubconf.py

* cleanuip
  • Loading branch information
glenn-jocher committed Mar 14, 2021
1 parent 435d853 commit 13e90ca
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ To run **batched inference** with YOLOv5 and [PyTorch Hub](https://github.com/ul
import torch

# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# Images
dir = 'https://github.com/ultralytics/yolov5/raw/master/data/images/'
imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batched list of images
imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batch of images

# Inference
results = model(imgs)
Expand Down
8 changes: 4 additions & 4 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create(name, pretrained, channels, classes, autoshape):
raise Exception(s) from e


def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
"""YOLOv5-small model from https://github.com/ultralytics/yolov5
Arguments:
Expand All @@ -65,7 +65,7 @@ def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
return create('yolov5s', pretrained, channels, classes, autoshape)


def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5
Arguments:
Expand All @@ -79,7 +79,7 @@ def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
return create('yolov5m', pretrained, channels, classes, autoshape)


def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
"""YOLOv5-large model from https://github.com/ultralytics/yolov5
Arguments:
Expand All @@ -93,7 +93,7 @@ def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
return create('yolov5l', pretrained, channels, classes, autoshape)


def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
Arguments:
Expand Down
14 changes: 11 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
from utils.plots import color_list, plot_one_box
from utils.torch_utils import time_synchronized


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -190,6 +191,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
# torch: = torch.zeros(16,3,720,1280) # BCHW
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

t = [time_synchronized()]
p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
Expand All @@ -216,22 +218,25 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
t.append(time_synchronized())

# Inference
with torch.no_grad():
y = self.model(x, augment, profile)[0] # forward
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
t.append(time_synchronized())

# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_synchronized())

return Detections(imgs, y, files, self.names)
return Detections(imgs, y, files, t, self.names, x.shape)


class Detections:
# detections class for YOLOv5 inference results
def __init__(self, imgs, pred, files, names=None):
def __init__(self, imgs, pred, files, times, names=None, shape=None):
super(Detections, self).__init__()
d = pred[0].device # device
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
Expand All @@ -244,6 +249,8 @@ def __init__(self, imgs, pred, files, names=None):
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)
self.t = ((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
self.s = shape # inference BCHW shape

def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
colors = color_list()
Expand Down Expand Up @@ -271,6 +278,7 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir='

def print(self):
self.display(pprint=True) # print results
print(f'Speed: %.1f/%.1f/%.1f ms pre-process/inference/NMS per image at shape {tuple(self.s)}' % tuple(self.t))

def show(self):
self.display(show=True) # show results
Expand Down

0 comments on commit 13e90ca

Please sign in to comment.