Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformerLayer, TransformerBlock, C3TR modules #2333

Merged
merged 18 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
transformer block
  • Loading branch information
dingyiwei committed Feb 11, 2021
commit b479d24feee94a3df9f98a9b5716fe4a3a32cedc
58 changes: 58 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,57 @@ def forward(self, x):
return x


class TransformerLayer(nn.Module):
def __init__(self, c, num_heads):
super().__init__()

self.ln1 = nn.LayerNorm(c)
self.q = nn.Linear(c, c)
self.k = nn.Linear(c, c)
self.v = nn.Linear(c, c)
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
self.ln2 = nn.LayerNorm(c)
self.fc1 = nn.Linear(c, c)
self.fc2 = nn.Linear(c, c)
self.act = nn.SiLU()

def forward(self, x):
x_ = self.ln1(x)
x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x
x = self.ln2(x)
x = self.fc2(self.act(self.fc1(x))) + x
return x


class TransformerBlock(nn.Module):
def __init__(self, c1, c2, num_heads, num_layers):
super().__init__()

self.conv = None
if c1 != c2:
self.conv = Conv(c1, c2)
self.linear = nn.Linear(c2, c2)
self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
self.c2 = c2

def forward(self, x):
if self.conv is not None:
x = self.conv(x)
b, _, w, h = x.shape
p = x.flatten(2)
p = p.unsqueeze(0)
p = p.transpose(0, 3)
p = p.squeeze(3)
e = self.linear(p)
x = p + e

x = self.tr(x)
x = x.unsqueeze(3)
x = x.transpose(0, 3)
x = x.reshape(b, self.c2, w, h)
return x


class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
Expand Down Expand Up @@ -138,6 +189,13 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
self.m = nn.Sequential(*[BoT(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class C3TR(C3):
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e)
self.m = TransformerBlock(c_, c_, 4, n)


class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)):
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
pass

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3T]:
if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3T, C3TR]:
c1, c2 = ch[f], args[0]

# Normal
Expand All @@ -232,7 +232,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
# c2 = make_divisible(c2, 8) if c2 != no else c2

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
if m in [BottleneckCSP, C3, C3T, C3TR]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
Expand Down
2 changes: 1 addition & 1 deletion models/yolotrs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ backbone:
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3T, [1024, False]], # 9
[-1, 3, C3TR, [1024, False]], # 9
]

# YOLOv5 head
Expand Down