Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update #5

Merged
merged 26 commits into from
Jul 5, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update yolo.py
  • Loading branch information
glenn-jocher committed Jul 5, 2020
commit 4e2d24602d246231694ba1b4d3bf3bd01f027ea4
20 changes: 12 additions & 8 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input cha

# Build strides, anchors
m = self.model[-1] # Detect()
m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride
if isinstance(m, Detect):
s = 128 # 2x min stride
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride
self._initialize_biases() # only run once
# print('Strides: %s' % m.stride.tolist())

# Init weights, biases
torch_utils.initialize_weights(self)
Expand Down Expand Up @@ -146,7 +150,7 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers


def parse_model(md, ch): # model_dict, input_channels(3)
print('\n%3s%15s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = md['anchors'], md['nc'], md['depth_multiple'], md['width_multiple']
na = (len(anchors[0]) // 2) # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
Expand All @@ -161,7 +165,7 @@ def parse_model(md, ch): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, BottleneckCSP, CrossConv]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]

# Normal
Expand All @@ -182,7 +186,7 @@ def parse_model(md, ch): # model_dict, input_channels(3)
# c2 = make_divisible(c2, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m is BottleneckCSP:
if m in [BottleneckCSP, C3]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
Expand All @@ -198,7 +202,7 @@ def parse_model(md, ch): # model_dict, input_channels(3)
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
print('%3s%15s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
print('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
Expand Down