Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformerLayer, TransformerBlock, C3TR modules #2333

Merged
merged 18 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,56 @@ def fuseforward(self, x):
return self.act(self.conv(x))


class TransformerLayer(nn.Module):
def __init__(self, c, num_heads):
super().__init__()

self.ln1 = nn.LayerNorm(c)
self.q = nn.Linear(c, c, bias=False)
self.k = nn.Linear(c, c, bias=False)
self.v = nn.Linear(c, c, bias=False)
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
self.ln2 = nn.LayerNorm(c)
self.fc1 = nn.Linear(c, c, bias=False)
self.fc2 = nn.Linear(c, c, bias=False)

def forward(self, x):
x_ = self.ln1(x)
x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x
x = self.ln2(x)
x = self.fc2(self.fc1(x)) + x
return x


class TransformerBlock(nn.Module):
def __init__(self, c1, c2, num_heads, num_layers):
super().__init__()

self.conv = None
if c1 != c2:
self.conv = Conv(c1, c2)
self.linear = nn.Linear(c2, c2)
self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
self.c2 = c2

def forward(self, x):
if self.conv is not None:
x = self.conv(x)
b, _, w, h = x.shape
p = x.flatten(2)
p = p.unsqueeze(0)
p = p.transpose(0, 3)
p = p.squeeze(3)
e = self.linear(p)
x = p + e

x = self.tr(x)
x = x.unsqueeze(3)
x = x.transpose(0, 3)
x = x.reshape(b, self.c2, w, h)
return x


class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
Expand Down Expand Up @@ -88,6 +138,13 @@ def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class C3TR(C3):
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e)
self.m = TransformerBlock(c_, c_, 4, n)


class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)):
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3]:
C3, C3TR]:
c1, c2 = ch[f], args[0]

# Normal
Expand All @@ -233,7 +233,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
# c2 = make_divisible(c2, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
if m in [BottleneckCSP, C3, C3TR]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
Expand Down
48 changes: 48 additions & 0 deletions models/yolotrs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3TR, [1024, False]], # 9
]

# YOLOv5 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)

[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)

[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)

[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]