Skip to content

Commit

Permalink
PyTorch Hub models default to CUDA:0 if available (ultralytics#2472)
Browse files Browse the repository at this point in the history
* PyTorch Hub models default to CUDA:0 if available

* device as string bug fix
  • Loading branch information
glenn-jocher committed Mar 15, 2021
1 parent 63b9634 commit d585238
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
4 changes: 3 additions & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from models.yolo import Model
from utils.general import set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device

dependencies = ['torch', 'yaml']
set_logging()
Expand Down Expand Up @@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
return model
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
return model.to(device)

except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
Expand Down
4 changes: 2 additions & 2 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
# Display cache
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
if exists:
d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

Expand Down Expand Up @@ -485,7 +485,7 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
nc += 1
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

if nf == 0:
Expand Down
2 changes: 1 addition & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def check_git_status():
f"Use 'git pull' to update or 'git clone {url}' to download latest."
else:
s = f'up to date with {url} ✅'
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
except Exception as e:
print(e)

Expand Down
6 changes: 3 additions & 3 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# PyTorch utils

import logging
import math
import os
import platform
import subprocess
import time
from contextlib import contextmanager
Expand Down Expand Up @@ -53,7 +53,7 @@ def git_describe():

def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu'
if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
Expand All @@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
else:
s += 'CPU\n'

logger.info(s) # skip a line
logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
return torch.device('cuda:0' if cuda else 'cpu')


Expand Down

0 comments on commit d585238

Please sign in to comment.