Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update inference default to multi_label=False #2252

Merged
merged 4 commits into from
Feb 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import requests
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from PIL import Image

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
Expand Down
8 changes: 4 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test(data,
with torch.no_grad():
# Run model
t = time_synchronized()
inf_out, train_out = model(img, augment=augment) # inference and training outputs
out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t

# Compute loss
Expand All @@ -117,11 +117,11 @@ def test(data,
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True)
t1 += time_synchronized() - t

# Statistics per image
for si, pred in enumerate(output):
for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:]
nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class
Expand Down Expand Up @@ -209,7 +209,7 @@ def test(data,
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
Expand Down
9 changes: 5 additions & 4 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,12 @@ def wh_iou(wh1, wh2):
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results

Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""

nc = prediction.shape[2] - 5 # number of classes
Expand All @@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS

t = time.time()
Expand Down
2 changes: 1 addition & 1 deletion utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def butter_lowpass(cutoff, fs, order):
return filtfilt(b, a, data) # forward-backward filter


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
# Plots one bounding box on image img
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)]
Expand Down