forked from Jongchan/attention-module
Commit e4ee180 (0 parents). Showing 6 changed files with 690 additions and 0 deletions.
bam.py (new file, 49 lines added)
import torch
import math
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()
        self.gate_c.add_module('flatten', Flatten())
        # Channel MLP: C -> C/r (num_layers times) -> C
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        for i in range(len(gate_channels) - 2):
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, in_tensor):
        # Global average pool to (N, C, 1, 1), run the MLP, then broadcast
        # the per-channel gate back to the input's spatial size.
        avg_pool = F.avg_pool2d(in_tensor, in_tensor.size(2), stride=in_tensor.size(2))
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)


class SpatialGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        self.gate_s = nn.Sequential()
        self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel // reduction_ratio, kernel_size=1))
        self.gate_s.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel // reduction_ratio))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
        for i in range(dilation_conv_num):
            # Dilated 3x3 convs enlarge the receptive field at full resolution.
            self.gate_s.add_module('gate_s_conv_di_%d' % i,
                                   nn.Conv2d(gate_channel // reduction_ratio, gate_channel // reduction_ratio,
                                             kernel_size=3, padding=dilation_val, dilation=dilation_val))
            self.gate_s.add_module('gate_s_bn_di_%d' % i, nn.BatchNorm2d(gate_channel // reduction_ratio))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
        self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(gate_channel // reduction_ratio, 1, kernel_size=1))

    def forward(self, in_tensor):
        return self.gate_s(in_tensor).expand_as(in_tensor)


class BAM(nn.Module):
    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        # Refine the input as F' = F * (1 + sigmoid(M_c(F) * M_s(F))).
        att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor
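As a quick smoke test (my sketch, not part of the commit; the module name `bam` is taken from the `from bam import *` line in the ResNet file below, and the batch/channel sizes are arbitrary), BAM should leave the feature-map shape untouched:

# Hypothetical usage sketch, not part of this commit.
import torch
from bam import BAM

x = torch.randn(2, 64, 32, 32)    # (batch, channels, height, width); H == W is assumed by ChannelGate's pooling
bam = BAM(gate_channel=64)
bam.eval()                        # use BatchNorm running stats for a single pass
with torch.no_grad():
    y = bam(x)
assert y.shape == x.shape         # attention rescales activations, never reshapes them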
cbam.py (new file, 95 lines added)
import torch
import math
import torch.nn as nn
import torch.nn.functional as F


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # Shared MLP applied to every pooled channel descriptor.
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool: numerically stable log-sum-exp over the spatial dims.
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    # Subtract the max before exponentiating for numerical stability.
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    def forward(self, x):
        # Stack channel-wise max and mean into a 2-channel spatial map.
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasts over channels
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # Channel attention first, then (optionally) spatial attention.
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
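As a hedged sanity check (my sketch, not in the commit; shapes are assumed), CBAM should preserve the input shape, and `logsumexp_2d` should agree with PyTorch's built-in `torch.logsumexp` over the flattened spatial dimensions:

# Hypothetical check, not part of this commit.
import torch
from cbam import CBAM, logsumexp_2d

x = torch.randn(2, 64, 32, 32)
cbam = CBAM(gate_channels=64, reduction_ratio=16, pool_types=['avg', 'max'])
cbam.eval()
with torch.no_grad():
    y = cbam(x)
assert y.shape == x.shape

# logsumexp_2d is the max-shifted form of logsumexp over H*W.
flat = x.view(x.size(0), x.size(1), -1)
ref = torch.logsumexp(flat, dim=2, keepdim=True)
assert torch.allclose(logsumexp_2d(x), ref, atol=1e-5)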
model_resnet.py (new file, 205 lines added)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import init
from cbam import *
from bam import *


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            self.cbam = CBAM(planes, 16)
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # CBAM refines the block output before the residual addition.
        if self.cbam is not None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            # The block output has planes * 4 channels after conv3.
            self.cbam = CBAM(planes * 4, 16)
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.cbam is not None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, network_type, num_classes, att_type=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.network_type = network_type
        # different model config between ImageNet and CIFAR
        if network_type == "ImageNet":
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AvgPool2d(7)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        if att_type == 'BAM':
            # BAM modules sit between the residual stages.
            self.bam1 = BAM(64 * block.expansion)
            self.bam2 = BAM(128 * block.expansion)
            self.bam3 = BAM(256 * block.expansion)
        else:
            self.bam1, self.bam2, self.bam3 = None, None, None

        self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        init.kaiming_normal_(self.fc.weight)
        for key in self.state_dict():
            if key.split('.')[-1] == "weight":
                if "conv" in key:
                    init.kaiming_normal_(self.state_dict()[key], mode='fan_out')
                if "bn" in key:
                    # Zero-init the BN scale inside CBAM's SpatialGate so its
                    # attention map starts flat; all other BN scales start at 1.
                    if "SpatialGate" in key:
                        self.state_dict()[key][...] = 0
                    else:
                        self.state_dict()[key][...] = 1
            elif key.split(".")[-1] == 'bias':
                self.state_dict()[key][...] = 0

    def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type == 'CBAM'))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_cbam=att_type == 'CBAM'))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.network_type == "ImageNet":
            x = self.maxpool(x)

        x = self.layer1(x)
        if self.bam1 is not None:
            x = self.bam1(x)

        x = self.layer2(x)
        if self.bam2 is not None:
            x = self.bam2(x)

        x = self.layer3(x)
        if self.bam3 is not None:
            x = self.bam3(x)

        x = self.layer4(x)

        if self.network_type == "ImageNet":
            x = self.avgpool(x)
        else:
            x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def ResidualNet(network_type, depth, num_classes, att_type):
    assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'

    if depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)

    elif depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)

    return model
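A minimal usage sketch of the `ResidualNet` factory (mine, not from the commit; it assumes the file is importable as `model_resnet` and uses 32x32 CIFAR-style inputs so the `F.avg_pool2d(x, 4)` branch lines up):

# Hypothetical sanity check, not part of this commit: build a CIFAR-100
# ResNet-18 with CBAM and run one forward pass.
import torch
from model_resnet import ResidualNet

model = ResidualNet("CIFAR100", depth=18, num_classes=100, att_type='CBAM')
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # expected: torch.Size([2, 100])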
Training script for ResNet-50 with BAM on ImageNet (new file, 11 lines added)
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python train_imagenet.py \
    --ngpu 4 \
    --workers 20 \
    --arch resnet --depth 50 \
    --epochs 100 \
    --batch-size 256 --lr 0.1 \
    --att-type BAM \
    --prefix RESNET50_IMAGENET_BAM \
    ./data/ImageNet/ \
    > logs/RESNET50_IMAGENET_BAM.log
Training script for ResNet-50 with CBAM on ImageNet (new file, 11 lines added)
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python train_imagenet.py \
    --ngpu 4 \
    --workers 20 \
    --arch resnet --depth 50 \
    --epochs 100 \
    --batch-size 256 --lr 0.1 \
    --att-type CBAM \
    --prefix RESNET50_IMAGENET_CBAM \
    ./data/ImageNet/ \
    > logs/RESNET50_IMAGENET_CBAM.log
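The two scripts differ only in `--att-type` and the naming prefix. A hypothetical lighter variant, reusing only the flags the scripts above already show (whether `train_imagenet.py` handles a single GPU and depth 18 this way is an assumption on my part; `ResidualNet` does accept depth 18):

# Hypothetical single-GPU ResNet-18 + CBAM run, not part of this commit.
CUDA_VISIBLE_DEVICES=0 \
python train_imagenet.py \
    --ngpu 1 \
    --workers 8 \
    --arch resnet --depth 18 \
    --epochs 100 \
    --batch-size 64 --lr 0.1 \
    --att-type CBAM \
    --prefix RESNET18_IMAGENET_CBAM \
    ./data/ImageNet/ \
    > logs/RESNET18_IMAGENET_CBAM.log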
(The sixth changed file did not render in this view.)