Add autobatch feature for best batch-size estimation #5092

Merged: 47 commits, Oct 25, 2021
Commits (47)
7798401  Autobatch (glenn-jocher, Oct 7, 2021)
38ee33c  fix mem (glenn-jocher, Oct 7, 2021)
9bab757  fix mem2 (glenn-jocher, Oct 7, 2021)
8a0ee56  Update (glenn-jocher, Oct 7, 2021)
8a68891  Update (glenn-jocher, Oct 7, 2021)
6f67028  Update (glenn-jocher, Oct 7, 2021)
ccd47a0  Update (glenn-jocher, Oct 7, 2021)
05d7860  Update (glenn-jocher, Oct 7, 2021)
78cbd2a  Update (glenn-jocher, Oct 7, 2021)
3b34bd4  Update (glenn-jocher, Oct 7, 2021)
cc09ecf  Update (glenn-jocher, Oct 7, 2021)
6d0b3e9  Update (glenn-jocher, Oct 7, 2021)
9282c21  Update (glenn-jocher, Oct 7, 2021)
45ddb57  Update (glenn-jocher, Oct 7, 2021)
b1a57d1  Update (glenn-jocher, Oct 7, 2021)
122733d  Update (glenn-jocher, Oct 7, 2021)
13c4996  Update (glenn-jocher, Oct 7, 2021)
bd34ab8  Update (glenn-jocher, Oct 7, 2021)
bbe56b8  Update (glenn-jocher, Oct 7, 2021)
aef68c9  Update (glenn-jocher, Oct 7, 2021)
3faf055  Update (glenn-jocher, Oct 7, 2021)
831593b  Update (glenn-jocher, Oct 7, 2021)
65e3bf6  Update (glenn-jocher, Oct 7, 2021)
6a0c4d2  Update (glenn-jocher, Oct 7, 2021)
6fa9834  Update train.py (glenn-jocher, Oct 8, 2021)
888f55c  print result (glenn-jocher, Oct 8, 2021)
d2f47bc  Cleanup print result (glenn-jocher, Oct 8, 2021)
c94026a  swap fix in call (glenn-jocher, Oct 8, 2021)
afdfcfb  to 64 (glenn-jocher, Oct 8, 2021)
ab7cc12  use total (glenn-jocher, Oct 8, 2021)
a036dd4  fix (glenn-jocher, Oct 8, 2021)
f6f80ed  fix (glenn-jocher, Oct 8, 2021)
e601f42  fix (glenn-jocher, Oct 8, 2021)
f55ad0b  fix (glenn-jocher, Oct 8, 2021)
58ed6af  fix (glenn-jocher, Oct 8, 2021)
18f5dd3  Update (glenn-jocher, Oct 8, 2021)
4b39534  Update (glenn-jocher, Oct 8, 2021)
1c9b42a  Update (glenn-jocher, Oct 8, 2021)
5c2e235  Update (glenn-jocher, Oct 8, 2021)
08f8e17  Update (glenn-jocher, Oct 8, 2021)
a9c00fa  Update (glenn-jocher, Oct 8, 2021)
ccabcd3  Update (glenn-jocher, Oct 8, 2021)
af25dbc  Cleanup printing (glenn-jocher, Oct 8, 2021)
3e1c74f  Update final printout (glenn-jocher, Oct 25, 2021)
602cf9a  Update autobatch.py (glenn-jocher, Oct 25, 2021)
126c13a  Update autobatch.py (glenn-jocher, Oct 25, 2021)
327954f  Update autobatch.py (glenn-jocher, Oct 25, 2021)
train.py: 17 changes (11 additions, 6 deletions)
@@ -36,6 +36,7 @@
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
+ from utils.autobatch import check_train_batch_size
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
@@ -131,6 +132,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
print(f'freezing {k}')
v.requires_grad = False

+ # Image size
+ gs = max(int(model.stride.max()), 32) # grid size (max stride)
+ imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple

+ # Batch size
+ if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
+ batch_size = check_train_batch_size(model, imgsz)

# Optimizer
nbs = 64 # nominal batch size
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
@@ -190,11 +199,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

del ckpt, csd

- # Image sizes
- gs = max(int(model.stride.max()), 32) # grid size (max stride)
- nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
- imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple

# DP mode
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
@@ -242,6 +246,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

# Model parameters
+ nl = model.model[-1].nl # number of detection layers (to scale hyps)
hyp['box'] *= 3. / nl # scale to layers
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
@@ -440,7 +445,7 @@ def parse_opt(known=False):
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
parser.add_argument('--epochs', type=int, default=300)
- parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
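For orientation, here is a minimal sketch (not code from the PR) of how the new hook is ordered inside train(): autobatch runs only for single-GPU training (RANK == -1) and only when the user passes --batch-size -1, and it must run before optimizer setup because the gradient-accumulation count is derived from the final batch size. The helper name resolve_batch_size and its parameters are hypothetical; check_train_batch_size is the function the PR adds.

from utils.autobatch import check_train_batch_size

def resolve_batch_size(model, imgsz, batch_size, rank, nbs=64):
    # Hypothetical helper mirroring the PR's ordering: estimate the batch size, then accumulation.
    if rank == -1 and batch_size == -1:  # single-GPU run that requested autobatch via --batch-size -1
        batch_size = check_train_batch_size(model, imgsz)  # estimate from available CUDA memory
    accumulate = max(round(nbs / batch_size), 1)  # steps to accumulate before optimizing (nominal batch 64)
    return batch_size, accumulate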
utils/autobatch.py: 56 changes (56 additions, 0 deletions)
@@ -0,0 +1,56 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Auto-batch utils
"""

from copy import deepcopy

import numpy as np
import torch
from torch.cuda import amp

from utils.general import colorstr
from utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640):
    # Check YOLOv5 training batch size
    with amp.autocast():
        return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size


def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
    # Automatically estimate best batch size to use `fraction` of available CUDA memory
    # Usage:
    #     import torch
    #     from utils.autobatch import autobatch
    #     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
    #     print(autobatch(model))

    prefix = colorstr('autobatch: ')
    print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
    device = next(model.parameters()).device  # get model device
    if device.type == 'cpu':
        print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
        return batch_size

    d = str(device).upper()  # 'CUDA:0'
    t = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3  # total (GB)
    r = torch.cuda.memory_reserved(device) / 1024 ** 3  # reserved (GB)
    a = torch.cuda.memory_allocated(device) / 1024 ** 3  # allocated (GB)
    f = t - (r + a)  # free inside reserved (GB)
    print(f'{prefix}{d} {t:.3g}G total, {r:.3g}G reserved, {a:.3g}G allocated, {f:.3g}G free')

    batch_sizes = [1, 2, 4, 8, 16]
    try:
        img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
        y = profile(img, model, n=3, device=device)
    except Exception as e:
        print(f'{prefix}{e}')
        return batch_size  # profiling failed, fall back to default batch size

    y = [x[2] for x in y if x]  # CUDA memory (GB) at index [2] of each successful profile result
    batch_sizes = batch_sizes[:len(y)]
    p = np.polyfit(batch_sizes, y, deg=1)  # first-degree fit: memory ~= p[0] * batch_size + p[1]
    b = int((f * fraction - p[1]) / p[0])  # solve the fit for the batch size at `fraction` of free memory
    print(f'{prefix}Using batch-size {b} for {d} {t * fraction:.3g}G/{t:.3g}G ({fraction * 100:.0f}%)')
    return b
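To make the estimate concrete, the following toy example reproduces the fit that autobatch() performs; the memory readings, free-memory figure, and target fraction below are made-up illustration values, not measurements from the PR.

import numpy as np

batch_sizes = [1, 2, 4, 8, 16]               # probed batch sizes, as in autobatch()
mem_gb = [1.2, 1.9, 3.3, 6.1, 11.8]          # hypothetical profiled CUDA memory per batch size (GB)
free_gb, fraction = 14.0, 0.9                # hypothetical free memory (GB) and target utilization

p = np.polyfit(batch_sizes, mem_gb, deg=1)   # linear model: memory ~= p[0] * batch_size + p[1]
b = int((free_gb * fraction - p[1]) / p[0])  # batch size expected to fill ~90% of free memory
print(f'estimated batch size: {b}')          # prints 17 with these illustrative numbers

With these numbers the fitted slope is roughly 0.7 GB per unit of batch size, so a 14 GB free budget targeted at 90% utilization lands at a batch size of 17.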
utils/torch_utils.py: 2 changes (1 addition, 1 deletion)
@@ -126,7 +126,7 @@ def profile(input, ops, n=10, device=None):
_ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
t[2] = time_sync()
except Exception as e: # no backward method
- print(e)
+ # print(e) # for debug
t[2] = float('nan')
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
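For context on why the print is silenced rather than removed: autobatch deliberately probes batch sizes that may fail the backward pass (for example with CUDA out-of-memory), and those failures should simply yield NaN timings instead of console errors. The sketch below shows how autobatch consumes profile() output; the toy Conv2d stand-in and small image sizes are assumptions for illustration, not anything from the PR, while index [2] being the memory reading matches the usage in utils/autobatch.py.

import torch
from utils.torch_utils import profile

model = torch.nn.Conv2d(3, 16, 3, padding=1)           # toy stand-in module, not a YOLOv5 model
imgs = [torch.zeros(b, 3, 64, 64) for b in (1, 2, 4)]  # increasing batch sizes, as autobatch probes
results = profile(imgs, model, n=1)                    # one result row per probed input
mem = [r[2] for r in results if r]                     # memory (GB) at index [2]; failed probes are skipped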