From bf6eb9de522d5e991052f8106bde30e936628b79 Mon Sep 17 00:00:00 2001
From: Ding Yiwei <16083536+dingyiwei@users.noreply.github.com>
Date: Thu, 1 Apr 2021 23:26:53 +0800
Subject: [PATCH] Add TransformerLayer, TransformerBlock, C3TR modules (#2333)

* yolotr

* transformer block

* Remove bias in Transformer

* Remove C3T

* Remove a deprecated class

* put the 2nd LayerNorm into the 2nd residual block

* move example model to models/hub, rename to -transformer

* Add module comments and TODOs

* Remove LN in Transformer

* Add comments for Transformer

* Solve the problem of MA with DDP

* cleanup

* cleanup find_unused_parameters

* PEP8 reformat

Co-authored-by: DingYiwei <846414640@qq.com>
Co-authored-by: Glenn Jocher
---
 models/common.py                    | 54 +++++++++++++++++++++++++++++
 models/hub/yolov5s-transformer.yaml | 48 +++++++++++++++++++++++++
 models/yolo.py                      |  4 ++--
 train.py                            |  4 +++-
 4 files changed, 107 insertions(+), 3 deletions(-)
 create mode 100644 models/hub/yolov5s-transformer.yaml

diff --git a/models/common.py b/models/common.py
index 5c0e571b752f..a25172dcfcac 100644
--- a/models/common.py
+++ b/models/common.py
@@ -43,6 +43,52 @@ def fuseforward(self, x):
         return self.act(self.conv(x))
 
 
+class TransformerLayer(nn.Module):
+    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
+    def __init__(self, c, num_heads):
+        super().__init__()
+        self.q = nn.Linear(c, c, bias=False)
+        self.k = nn.Linear(c, c, bias=False)
+        self.v = nn.Linear(c, c, bias=False)
+        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
+        self.fc1 = nn.Linear(c, c, bias=False)
+        self.fc2 = nn.Linear(c, c, bias=False)
+
+    def forward(self, x):
+        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
+        x = self.fc2(self.fc1(x)) + x
+        return x
+
+
+class TransformerBlock(nn.Module):
+    # Vision Transformer https://arxiv.org/abs/2010.11929
+    def __init__(self, c1, c2, num_heads, num_layers):
+        super().__init__()
+        self.conv = None
+        if c1 != c2:
+            self.conv = Conv(c1, c2)
+        self.linear = nn.Linear(c2, c2)  # learnable position embedding
+        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
+        self.c2 = c2
+
+    def forward(self, x):
+        if self.conv is not None:
+            x = self.conv(x)
+        b, _, w, h = x.shape
+        p = x.flatten(2)
+        p = p.unsqueeze(0)
+        p = p.transpose(0, 3)
+        p = p.squeeze(3)
+        e = self.linear(p)
+        x = p + e
+
+        x = self.tr(x)
+        x = x.unsqueeze(3)
+        x = x.transpose(0, 3)
+        x = x.reshape(b, self.c2, w, h)
+        return x
+
+
 class Bottleneck(nn.Module):
     # Standard bottleneck
     def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
@@ -90,6 +136,14 @@ def forward(self, x):
         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
 
 
+class C3TR(C3):
+    # C3 module with TransformerBlock()
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = TransformerBlock(c_, c_, 4, n)
+
+
 class SPP(nn.Module):
     # Spatial pyramid pooling layer used in YOLOv3-SPP
     def __init__(self, c1, c2, k=(5, 9, 13)):
diff --git a/models/hub/yolov5s-transformer.yaml b/models/hub/yolov5s-transformer.yaml
new file mode 100644
index 000000000000..f2d666722b30
--- /dev/null
+++ b/models/hub/yolov5s-transformer.yaml
@@ -0,0 +1,48 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 0.33  # model depth multiple
+width_multiple: 0.50  # layer channel multiple
+
+# anchors
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Focus, [64, 3]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 9, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 1, SPP, [1024, [5, 9, 13]]],
+   [-1, 3, C3TR, [1024, False]],  # 9 <-------- C3TR() Transformer module
+  ]
+
+# YOLOv5 head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14], 1, Concat, [1]],  # cat head P4
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]
diff --git a/models/yolo.py b/models/yolo.py
index e5c676dae558..f730a1efa3b3 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -215,13 +215,13 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
 
         n = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
-                 C3]:
+                 C3, C3TR]:
             c1, c2 = ch[f], args[0]
             if c2 != no:  # if not output
                 c2 = make_divisible(c2 * gw, 8)
 
             args = [c1, c2, *args[1:]]
-            if m in [BottleneckCSP, C3]:
+            if m in [BottleneckCSP, C3, C3TR]:
                 args.insert(2, n)  # number of repeats
                 n = 1
         elif m is nn.BatchNorm2d:
diff --git a/train.py b/train.py
index d55833bf45a3..1f2b467e732b 100644
--- a/train.py
+++ b/train.py
@@ -218,7 +218,9 @@ def train(hyp, opt, device, tb_writer=None):
 
     # DDP mode
     if cuda and rank != -1:
-        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
+        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
+                    # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
+                    find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
 
     # Model parameters
     hyp['box'] *= 3. / nl  # scale to layers
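
A minimal smoke-test sketch for the TransformerBlock module added above, assuming it is run from the repository root so that models.common is importable; the (2, 256, 20, 20) input tensor and the num_heads/num_layers values are arbitrary example choices, not part of the patch:

    import torch
    from models.common import TransformerBlock

    x = torch.randn(2, 256, 20, 20)  # example feature map: (batch, channels, height, width)
    m = TransformerBlock(c1=256, c2=256, num_heads=4, num_layers=1)  # c1 == c2, so no Conv projection is inserted
    y = m(x)  # pixels are flattened to (h*w, batch, c2) tokens, run through attention, then reshaped back
    print(y.shape)  # torch.Size([2, 256, 20, 20]) - spatial shape is preserved

The same modules are reachable from a model config: models/hub/yolov5s-transformer.yaml above uses C3TR for backbone layer 9 in place of the usual C3, and parse_model() in models/yolo.py now recognizes C3TR.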