
Commit

initial commit
shwoo93 committed Oct 7, 2018
0 parents commit e4ee180
Showing 6 changed files with 690 additions and 0 deletions.
49 changes: 49 additions & 0 deletions MODELS/bam.py
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()
        self.gate_c.add_module('flatten', Flatten())
        # MLP: gate_channel -> gate_channel // reduction_ratio (x num_layers) -> gate_channel
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        for i in range(len(gate_channels) - 2):
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, in_tensor):
        # global average pooling over the spatial dimensions, then the channel MLP
        avg_pool = F.avg_pool2d(in_tensor, in_tensor.size(2), stride=in_tensor.size(2))
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)

class SpatialGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        self.gate_s = nn.Sequential()
        self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel // reduction_ratio, kernel_size=1))
        self.gate_s.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel // reduction_ratio))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
        # dilated 3x3 convolutions enlarge the receptive field without changing resolution
        for i in range(dilation_conv_num):
            self.gate_s.add_module('gate_s_conv_di_%d' % i, nn.Conv2d(gate_channel // reduction_ratio, gate_channel // reduction_ratio,
                                                                      kernel_size=3, padding=dilation_val, dilation=dilation_val))
            self.gate_s.add_module('gate_s_bn_di_%d' % i, nn.BatchNorm2d(gate_channel // reduction_ratio))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
        self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(gate_channel // reduction_ratio, 1, kernel_size=1))

    def forward(self, in_tensor):
        return self.gate_s(in_tensor).expand_as(in_tensor)

class BAM(nn.Module):
    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        # combine the two attention maps, then rescale the input (identity plus attention)
        att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor
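
Not part of the commit, but a quick way to sanity-check the module above — a minimal smoke test, assuming bam.py is importable from the working directory. The batch size must be at least 2 because ChannelGate uses BatchNorm1d in training mode:

import torch
from bam import BAM

x = torch.randn(2, 64, 56, 56)   # NCHW feature map, e.g. a ResNet stage output
bam = BAM(gate_channel=64)
out = bam(x)
print(out.shape)                 # torch.Size([2, 64, 56, 56]); BAM preserves the shape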
95 changes: 95 additions & 0 deletions MODELS/cbam.py
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # shared MLP applied to every pooled descriptor
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool: a smooth approximation of max pooling
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    # numerically stable log-sum-exp over the flattened spatial dimensions
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        # stack channel-wise max and mean into a 2-channel descriptor
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
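
A hedged usage sketch (not in the commit): CBAM drops into any feature map the same way, here with the default avg+max channel pooling, assuming cbam.py is importable:

import torch
from cbam import CBAM

x = torch.randn(2, 64, 56, 56)
cbam = CBAM(gate_channels=64, reduction_ratio=16, pool_types=['avg', 'max'])
out = cbam(x)                    # channel gate first, then spatial gate
assert out.shape == x.shape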
205 changes: 205 additions & 0 deletions MODELS/model_resnet_v2.py
@@ -0,0 +1,205 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from cbam import *
from bam import *

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            self.cbam = CBAM(planes, 16)
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # CBAM refines the residual branch before the skip connection is added
        if self.cbam is not None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        if use_cbam:
            # the bottleneck output has planes * 4 channels
            self.cbam = CBAM(planes * 4, 16)
        else:
            self.cbam = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.cbam is not None:
            out = self.cbam(out)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, network_type, num_classes, att_type=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.network_type = network_type
        # different model config between ImageNet and CIFAR
        if network_type == "ImageNet":
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AvgPool2d(7)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # BAM modules sit between the residual stages
        if att_type == 'BAM':
            self.bam1 = BAM(64 * block.expansion)
            self.bam2 = BAM(128 * block.expansion)
            self.bam3 = BAM(256 * block.expansion)
        else:
            self.bam1, self.bam2, self.bam3 = None, None, None

        self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        init.kaiming_normal_(self.fc.weight)
        for key in self.state_dict():
            if key.split('.')[-1] == "weight":
                if "conv" in key:
                    init.kaiming_normal_(self.state_dict()[key], mode='fan_out')
                if "bn" in key:
                    # zero-init the SpatialGate BN so the spatial attention map starts out uniform
                    if "SpatialGate" in key:
                        self.state_dict()[key][...] = 0
                    else:
                        self.state_dict()[key][...] = 1
            elif key.split(".")[-1] == 'bias':
                self.state_dict()[key][...] = 0

    def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type == 'CBAM'))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_cbam=att_type == 'CBAM'))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.network_type == "ImageNet":
            x = self.maxpool(x)

        x = self.layer1(x)
        if self.bam1 is not None:
            x = self.bam1(x)

        x = self.layer2(x)
        if self.bam2 is not None:
            x = self.bam2(x)

        x = self.layer3(x)
        if self.bam3 is not None:
            x = self.bam3(x)

        x = self.layer4(x)

        if self.network_type == "ImageNet":
            x = self.avgpool(x)
        else:
            x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def ResidualNet(network_type, depth, num_classes, att_type):
    assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'

    if depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)

    elif depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)

    elif depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)

    return model
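
A minimal construction sketch (not part of the commit), assuming model_resnet_v2.py is importable; CIFAR inputs are 32x32, so the final 4x4 average pool lines up with the layer4 output:

import torch
from model_resnet_v2 import ResidualNet

model = ResidualNet("CIFAR100", depth=18, num_classes=100, att_type='CBAM')
x = torch.randn(2, 3, 32, 32)    # batch of 2 so BatchNorm works in train mode
logits = model(x)
print(logits.shape)              # torch.Size([2, 100])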
11 changes: 11 additions & 0 deletions scripts/train_imagenet_resnet50_bam.sh
@@ -0,0 +1,11 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python train_imagenet.py \
--ngpu 4 \
--workers 20 \
--arch resnet --depth 50 \
--epochs 100 \
--batch-size 256 --lr 0.1 \
--att-type BAM \
--prefix RESNET50_IMAGENET_BAM \
./data/ImageNet/ \
> logs/RESNET50_IMAGENET_BAM.log
11 changes: 11 additions & 0 deletions scripts/train_imagenet_resnet50_cbam.sh
@@ -0,0 +1,11 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python train_imagenet.py \
--ngpu 4 \
--workers 20 \
--arch resnet --depth 50 \
--epochs 100 \
--batch-size 256 --lr 0.1 \
--att-type CBAM \
--prefix RESNET50_IMAGENET_CBAM \
./data/ImageNet/ \
> logs/RESNET50_IMAGENET_CBAM.log