-
Notifications
You must be signed in to change notification settings - Fork 0
/
deeplabV3plus.py
339 lines (315 loc) · 11.5 KB
/
deeplabV3plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from res2net import res2net50_48w_2s
# Normalization layer shared by every block below and by the isinstance checks
# in the weight-init / BN-freeze helpers, so the layer type is set in one place.
BatchNorm2d = nn.BatchNorm2d
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 (optionally dilated) -> 1x1 expand.

    The block outputs ``planes * expansion`` channels. When the input cannot be
    added to that output directly (different channel count or stride), the
    caller must pass a ``downsample`` module that maps the input to the same
    shape for the residual addition.

    Args:
        inplanes: channels of the input tensor.
        planes: width of the inner 3x3 convolution.
        stride: stride of the 3x3 convolution.
        dilation: dilation of the 3x3 convolution; padding equals the dilation,
            so spatial size is changed only by ``stride``.
        downsample: optional module applied to the input before the residual add.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the output width from ``expansion`` (not a literal 4) so it always
        # agrees with the downsample path that ResNet sizes via ``block.expansion``.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet backbone for DeepLab: dilated late stages plus a multi-grid ``layer4``.

    Args:
        nInputChannels: channels of the input image.
        block: residual block class. It must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, dilation, downsample)``.
        layers: number of blocks in layer1, layer2 and layer3. A fourth entry,
            if present, is not used: layer4's length comes from the multi-grid.
        os: output stride (16 or 8), i.e. input size divided by the size of the
            deepest feature map returned by ``forward``.
        pretrained: copy shape-compatible weights from the ResNet-101 checkpoint
            hosted at download.pytorch.org.

    Raises:
        NotImplementedError: if ``os`` is neither 16 nor 8.
    """

    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        # Per-stage (layer1..layer4) strides and dilations, plus the multi-grid
        # dilation multipliers applied to the blocks of layer4.
        if os == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
            blocks = [1, 2, 4]
        elif os == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError("Unsupported output stride: {}".format(os))
        # Stem: 7x7/2 conv + 3x3/2 max-pool (overall stride 4).
        self.conv1 = nn.Conv2d(
            nInputChannels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, 64, layers[0], stride=strides[0], dilation=dilations[0]
        )
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=strides[1], dilation=dilations[1]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=strides[2], dilation=dilations[2]
        )
        self.layer4 = self._make_MG_unit(
            block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3]
        )
        self._init_weight()
        if pretrained:
            self._load_pretrained_model()

    def _make_downsample(self, block, planes, stride):
        """Return a 1x1 conv + BN shortcut projection, or None when identity fits."""
        out_channels = planes * block.expansion
        if stride == 1 and self.inplanes == out_channels:
            return None
        return nn.Sequential(
            nn.Conv2d(
                self.inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
            ),
            BatchNorm2d(out_channels),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks, all using ``dilation``.

        Only the first block applies ``stride`` and the shortcut projection.
        Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = self._make_downsample(block, planes, stride)
        layers = [block(self.inplanes, planes, stride, dilation, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            # Keep the stage dilation for every block; falling back to the block
            # default (1) would leave a dilated stage dilated only in its first block.
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=(1, 2, 4), stride=1, dilation=1):
        """Build the multi-grid stage: block ``i`` uses dilation ``blocks[i] * dilation``.

        ``blocks`` defaults to an immutable tuple; any sequence of ints works.
        Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = self._make_downsample(block, planes, stride)
        layers = [
            block(
                self.inplanes,
                planes,
                stride,
                dilation=blocks[0] * dilation,
                downsample=downsample,
            )
        ]
        self.inplanes = planes * block.expansion
        for multiplier in blocks[1:]:
            layers.append(
                block(self.inplanes, planes, stride=1, dilation=multiplier * dilation)
            )
        return nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(deep_features, low_level_features)``.

        The low-level features are the output of ``layer1``.
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        low_level_feat = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat

    def _init_weight(self):
        """He-normal init (fan-out) for convs; BN weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        """Download the ResNet-101 checkpoint and load every compatible tensor."""
        pretrain_dict = model_zoo.load_url(
            "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
        )
        self._load_matching_weights(pretrain_dict)

    def _load_matching_weights(self, pretrain_dict):
        """Copy entries of ``pretrain_dict`` whose key AND shape match this model.

        Keys absent from the model (e.g. the classifier ``fc.*``) are skipped.
        Shape-mismatched tensors (e.g. ``conv1.weight`` when ``nInputChannels != 3``)
        are skipped as well, instead of making ``load_state_dict`` raise.
        """
        state_dict = self.state_dict()
        matched = {
            k: v
            for k, v in pretrain_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }
        state_dict.update(matched)
        self.load_state_dict(state_dict)
def ResNet101(nInputChannels=3, os=16, pretrained=False):
    """Return the DeepLab ResNet backbone built from Bottleneck blocks in a 3-4-23-3 layout."""
    blocks_per_stage = [3, 4, 23, 3]
    return ResNet(
        nInputChannels,
        Bottleneck,
        blocks_per_stage,
        os,
        pretrained=pretrained,
    )
class ASPP_module(nn.Module):
    """One ASPP branch: atrous convolution -> BatchNorm -> ReLU.

    A dilation of 1 degenerates to a plain 1x1 convolution. Any other rate uses
    a 3x3 kernel padded by the rate, so the spatial size is preserved either way.
    """

    def __init__(self, inplanes, planes, dilation):
        super(ASPP_module, self).__init__()
        pointwise = dilation == 1
        ksize = 1 if pointwise else 3
        pad = 0 if pointwise else dilation
        self.atrous_convolution = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=ksize,
            stride=1,
            padding=pad,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        """Apply conv, BN and ReLU in sequence."""
        return self.relu(self.bn(self.atrous_convolution(x)))

    def _init_weight(self):
        """He-normal init (fan-out) for convs; BN weight=1, bias=0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size
                fan_out = kh * kw * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
class DeepLabv3_plus(nn.Module):
    """DeepLabv3+ semantic segmentation network with a ResNet-101 backbone.

    Encoder: ResNet-101 features -> ASPP (four atrous branches + image pooling)
    -> 1x1 projection to 256 channels.
    Decoder: the encoder output is upsampled to the resolution of the low-level
    backbone features (reduced to 48 channels), concatenated with them, refined
    by two 3x3 convs, projected to ``n_classes`` and resized to the input size.

    Args:
        nInputChannels: number of channels of the input image.
        n_classes: number of output channels (class logits).
        os: backbone output stride, 16 or 8.
        pretrained: load pretrained ResNet-101 weights into the backbone.
        freeze_bn: keep every BatchNorm layer in eval mode, also after
            ``model.train()`` is called.
        _print: print a short summary of the configuration.

    Raises:
        NotImplementedError: if ``os`` is neither 16 nor 8.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=21,
        os=16,
        pretrained=False,
        freeze_bn=False,
        _print=True,
    ):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()
        # Remembered so that train() can re-apply the freeze.
        self._bn_frozen = freeze_bn
        # Backbone: returns (deep features, low-level features).
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)
        # ASPP dilation rates, chosen per output stride.
        if os == 16:
            dilations = [1, 6, 12, 18]
        elif os == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError("Unsupported output stride: {}".format(os))
        # Four parallel atrous branches with different receptive fields.
        self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])
        self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])
        self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])
        self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])
        self.relu = nn.ReLU()
        # Image-level branch: global average pooling + 1x1 conv.
        # NOTE: in training mode this BatchNorm sees a 1x1 map, so it needs a
        # batch size greater than 1.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(2048, 256, 1, stride=1, bias=False),
            BatchNorm2d(256),
            nn.ReLU(),
        )
        # Projection of the 5 concatenated ASPP outputs (5 * 256 = 1280 channels).
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm2d(256)
        # 1x1 conv reducing the low-level features to 48 channels.
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = BatchNorm2d(48)
        # Decoder head on the 256 + 48 = 304 concatenated channels.
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, n_classes, kernel_size=1, stride=1),
        )
        if freeze_bn:
            self._freeze_bn()

    def forward(self, input):
        """Return per-class logits resized to the spatial size of ``input``."""
        x, low_level_features = self.resnet_features(input)
        # ASPP: four atrous branches plus the image-pooling branch.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Decoder: reduce low-level features, then upsample the encoder output to
        # exactly their resolution (taken from the tensor rather than recomputed
        # from the input size, so it cannot drift from what the backbone produced).
        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = F.interpolate(
            x,
            size=low_level_features.size()[2:],
            mode="bilinear",
            align_corners=True,
        )
        x = torch.cat((x, low_level_features), dim=1)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input.size()[2:], mode="bilinear", align_corners=True)
        return x

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm layers in eval if ``freeze_bn`` was set.

        Without this override, the freeze applied in ``__init__`` would be undone
        by the first ``model.train()`` call.
        """
        super(DeepLabv3_plus, self).train(mode)
        if getattr(self, "_bn_frozen", False):
            self._freeze_bn()
        return self

    def _freeze_bn(self):
        """Put every BatchNorm layer in eval mode.

        Affine parameters still receive gradients.
        """
        for m in self.modules():
            if isinstance(m, BatchNorm2d):
                m.eval()

    def _init_weight(self):
        """He-normal init (fan-out) for convs; BN weight=1, bias=0.

        Not called by ``__init__``. It iterates over all modules, backbone
        included, so calling it would overwrite pretrained backbone weights.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
if __name__ == "__main__":
    # Smoke test: build the model and check the logits match the input resolution.
    # NOTE: pretrained=True downloads the ResNet-101 checkpoint on first use.
    model = DeepLabv3_plus(
        nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True
    )
    model.eval()
    image = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks run.
        output = model(image)
    print(output.size())