Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detection cropping+saving feature addition for detect.py and PyTorch Hub #2827

Merged
merged 16 commits into from
Apr 20, 2021
Next Next commit
Update detect.py
  • Loading branch information
burhr2 committed Apr 17, 2021
commit f2066945cb179f9eb5d208050b57167c866bbee9
27 changes: 24 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import time
from pathlib import Path
import numpy as np

import cv2
import torch
Expand All @@ -16,15 +17,16 @@


def detect(save_img=False):
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
source, weights, view_img, save_obj, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_obj, opt.save_txt, opt.img_size
save_img = not opt.nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
('rtsp://', 'rtmp://', 'http://', 'https://'))

# Directories
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

(save_dir / 'cropped' if save_obj else save_dir).mkdir(parents=True, exist_ok=True) # make dir for cropped objects

# Initialize
set_logging()
device = select_device(opt.device)
Expand Down Expand Up @@ -85,7 +87,8 @@ def detect(save_img=False):
p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
else:
p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)


im1=im0.copy() # making a copy of the original image
p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
Expand All @@ -101,6 +104,7 @@ def detect(save_img=False):
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Write results
k = 0 # counter for each object in an image
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
Expand All @@ -111,6 +115,22 @@ def detect(save_img=False):
if save_img or view_img: # Add bbox to image
label = f'{names[int(cls)]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

if save_obj: # save detected objects as a separate images
x,y,w,h=int(xyxy[0]), int(xyxy[1]), int(xyxy[2] - xyxy[0]), int(xyxy[3] - xyxy[1])
img_ = im1.astype(np.uint8)
crop_img=img_[y:y + h, x:x + w]

#!!Generating new file path for each detected object in an image !!!
filename=p.name
filename_no_extesion=filename.split('.')[0]
extension=filename.split('.')[1]
new_filename=str(filename_no_extesion) + '_' + str(k) + '.' + str(extension)
dir_path=os.path.join(save_dir,'cropped')
filepath=os.path.join(dir_path, new_filename)
print(filepath)
cv2.imwrite(filepath, crop_img)
k+=1

# Print time (inference + NMS)
print(f'{s}Done. ({t2 - t1:.3f}s)')
Expand Down Expand Up @@ -156,6 +176,7 @@ def detect(save_img=False):
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save_obj', action='store_false', help='save the detected object as separate image')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
Expand Down