Merge pull request #2 from bhpfelix/add_bn
Add bn
bhpfelix authored Jun 28, 2019
2 parents be50e5a + 3ed3966 commit e8843c8
Showing 9 changed files with 172 additions and 45 deletions.
configs/defaults.py (2 changes: 1 addition & 1 deletion)

@@ -19,7 +19,7 @@
 _C.MODEL.NET1_CLASSES = 40
 _C.MODEL.NET2_CLASSES = 3

-_C.MODEL.BN_BEFORE_RELU = False
+_C.MODEL.BN_BEFORE_RELU = True


 _C.TRAIN = CN()
configs/vgg16_nddr_sing_bn_before_relu.yaml (6 changes: 0 additions & 6 deletions)

This file was deleted.

configs/vgg16_nddr_sing_flat_nddr_fc8_lr_3000_steps.yaml (8 changes: 0 additions & 8 deletions)

This file was deleted.

This file was deleted.

eval.py (10 changes: 6 additions & 4 deletions)

@@ -9,7 +9,7 @@
 from data.loader import MultiTaskDataset

 from models.nddr_net import NDDRNet
-from models.vgg16_lfov import DeepLabLargeFOV
+from models.vgg16_lfov_bn import DeepLabLargeFOVBN

 from utils.metrics import compute_hist, compute_angle

@@ -85,9 +85,11 @@ def main():
         ),
         batch_size=cfg.TEST.BATCH_SIZE, shuffle=False)

-    net1 = DeepLabLargeFOV(3, cfg.MODEL.NET1_CLASSES, weights='')
-    net2 = DeepLabLargeFOV(3, cfg.MODEL.NET2_CLASSES, weights='')
-    model = NDDRNet(net1, net2)
+    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights='')
+    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights='')
+    model = NDDRNet(net1, net2,
+                    shortcut=cfg.MODEL.SHORTCUT,
+                    bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)
     ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME, 'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
     print("Evaluating Checkpoint at %s" % ckpt_path)
     ckpt = torch.load(ckpt_path)
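Note: eval.py now builds NDDRNet with the same SHORTCUT and BN_BEFORE_RELU flags as train.py, which matters because a checkpoint only loads into a model whose module tree matches the one it was saved from. A minimal illustration of the failure mode this avoids (hypothetical code, not part of the commit):

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self, shortcut=False):
        super(Tiny, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)
        if shortcut:
            # this submodule exists only when shortcut=True
            self.extra = nn.Conv2d(3, 3, kernel_size=1)

trained = Tiny(shortcut=True)
fresh = Tiny(shortcut=False)
try:
    fresh.load_state_dict(trained.state_dict())
except RuntimeError as err:
    print(err)  # Unexpected key(s) in state_dict: "extra.weight", "extra.bias"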
models/nddr_net.py (25 changes: 13 additions & 12 deletions)

@@ -5,21 +5,23 @@

 class NDDR(nn.Module):
     def __init__(self, out_channels, init_weights=[0.9, 0.1], init_method='constant', activation='relu',
-                 batch_norm=True, bn_before_relu=False):
+                 batch_norm=True, bn_before_relu=True, conv_bias=False):
         super(NDDR, self).__init__()
-        self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1)
-        self.conv2 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1)
+        self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=conv_bias)
+        self.conv2 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=conv_bias)
         if init_method == 'constant':
             self.conv1.weight = nn.Parameter(torch.cat([
                 torch.eye(out_channels) * init_weights[0],
                 torch.eye(out_channels) * init_weights[1]
             ], dim=1).view(out_channels, -1, 1, 1))
-            self.conv1.bias.data.fill_(0)
+            if self.conv1.bias is not None:  # bias is absent when conv_bias=False
+                self.conv1.bias.data.fill_(0)
             self.conv2.weight = nn.Parameter(torch.cat([
                 torch.eye(out_channels) * init_weights[1],
                 torch.eye(out_channels) * init_weights[0]
             ], dim=1).view(out_channels, -1, 1, 1))
-            self.conv2.bias.data.fill_(0)
+            if self.conv2.bias is not None:
+                self.conv2.bias.data.fill_(0)
         elif init_method == 'xavier':
             nn.init.xavier_uniform_(self.conv1.weight)
             nn.init.xavier_uniform_(self.conv2.weight)

@@ -37,8 +39,8 @@ def __init__(self, out_channels, init_weights=[0.9, 0.1], init_method='constant'
         self.batch_norm = batch_norm
         self.bn_before_relu = bn_before_relu
         if batch_norm:
-            self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.05)
-            self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.05)
+            self.bn1 = nn.BatchNorm2d(out_channels, eps=1e-03, momentum=0.05)
+            self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-03, momentum=0.05)

     def forward(self, feature1, feature2):
         x = torch.cat([feature1, feature2], 1)

@@ -57,7 +59,7 @@ def forward(self, feature1, feature2):


 class NDDRNet(nn.Module):
-    def __init__(self, net1, net2, init_weights=[0.9, 0.1], init_method='constant', activation='relu', batch_norm=True, shortcut=False, bn_before_relu=False):
+    def __init__(self, net1, net2, init_weights=[0.9, 0.1], init_method='constant', activation='relu', batch_norm=True, shortcut=False, bn_before_relu=True, conv_bias=False):
         super(NDDRNet, self).__init__()
         self.net1 = net1
         self.net2 = net2

@@ -70,16 +72,16 @@ def __init__(self, net1, net2, init_weights=[0.9, 0.1], init_method='constant',
             out_channels = net1.stages[stage_id].out_channels
             assert out_channels == net2.stages[stage_id].out_channels
             total_channels += out_channels
-            nddr = NDDR(out_channels, init_weights, init_method, activation, batch_norm, bn_before_relu)
+            nddr = NDDR(out_channels, init_weights, init_method, activation, batch_norm, bn_before_relu, conv_bias)
             nddrs.append(nddr)
         nddrs = nn.ModuleList(nddrs)

         self.shortcut = shortcut
         final_conv = None
         if shortcut:
             print("Using shortcut")
-            conv = nn.Conv2d(total_channels, net1.stages[-1].out_channels, kernel_size=1)
-            bn = nn.BatchNorm2d(net1.stages[-1].out_channels, momentum=0.05)
+            conv = nn.Conv2d(total_channels, net1.stages[-1].out_channels, kernel_size=1, bias=conv_bias)
+            bn = nn.BatchNorm2d(net1.stages[-1].out_channels, eps=1e-03, momentum=0.05)
             if bn_before_relu:
                 print("Using bn before relu")
                 final_conv = [conv, bn, nn.ReLU()]

@@ -112,7 +114,6 @@ def forward(self, x):
             y = self.nddrs['shortcut'](y)
         x = self.net1.head(x)
         y = self.net2.head(y)
-
         x = F.interpolate(x, (H, W), mode='bilinear', align_corners=True)
         y = F.interpolate(y, (H, W), mode='bilinear', align_corners=True)
         return x, y
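With init_method='constant' and init_weights=[0.9, 0.1], each NDDR 1x1 convolution starts as a weighted identity mix: a task initially keeps 90% of its own features and borrows 10% from the other task, and with conv_bias=False that is exactly the layer's initial output. A quick sanity check of that claim (hypothetical code, not part of the commit):

import torch
import torch.nn.functional as F

out_channels = 4
# same weight construction as NDDR.conv1 above
w = torch.cat([torch.eye(out_channels) * 0.9,
               torch.eye(out_channels) * 0.1], dim=1).view(out_channels, -1, 1, 1)
f1 = torch.randn(1, out_channels, 2, 2)
f2 = torch.randn(1, out_channels, 2, 2)
y = F.conv2d(torch.cat([f1, f2], dim=1), w)
assert torch.allclose(y, 0.9 * f1 + 0.1 * f2, atol=1e-6)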
models/vgg16_lfov.py (2 changes: 0 additions & 2 deletions)

@@ -90,11 +90,9 @@ def __init__(self, in_dim, out_dim, weights='DeepLab', *args, **kwargs):
         self.init_weights()

     def forward(self, x):
-        N, C, H, W = x.size()
         for stage in self.stages:
             x = stage(x)
         x = self.head(x)
-        x = F.interpolate(x, (H, W), mode='bilinear', align_corners=True)
         return x

     def init_weights(self):
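With this change the backbones return logits at feature resolution, and NDDRNet.forward owns the bilinear upsampling back to the input size (the F.interpolate context lines in models/nddr_net.py above). A standalone illustration of that call (hypothetical code, not part of the commit):

import torch
import torch.nn.functional as F

H, W = 321, 321                       # input size recorded before the stages run
logits = torch.randn(1, 40, 41, 41)   # backbone output at feature resolution
up = F.interpolate(logits, (H, W), mode='bilinear', align_corners=True)
print(up.shape)  # torch.Size([1, 40, 321, 321])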
models/vgg16_lfov_bn.py (new file: 149 additions & 0 deletions)

@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .common_layers import Stage


class DeepLabLargeFOVBN(nn.Module):
    def __init__(self, in_dim, out_dim, weights='DeepLab', *args, **kwargs):
        super(DeepLabLargeFOVBN, self).__init__(*args, **kwargs)
        self.stages = []
        layers = []

        stage = [
            nn.Conv2d(in_dim, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(64, stage))

        stage = [
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(128, stage))

        stage = [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(256, stage))

        stage = [
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))

        stage = [
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(512, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1),
            # must use count_include_pad=False to make sure result is same as TF
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))
        self.stages = nn.ModuleList(self.stages)

        self.features = nn.Sequential(*layers)

        head = [
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=12, dilation=12, bias=False),
            nn.BatchNorm2d(1024, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024, eps=1e-03, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, out_dim, kernel_size=1)
        ]
        self.head = nn.Sequential(*head)

        self.weights = weights
        self.init_weights()

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        x = self.head(x)
        return x

    def init_weights(self):
        for layer in self.head.children():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=1)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

        if self.weights == 'DeepLab':
            pretrained_dict = torch.load('weights/vgg_deeplab_lfov/tf_deeplab.pth')
            model_dict = self.state_dict()
            # 1. filter out unnecessary keys
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'head.7' not in k}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 3. load the new state dict
            self.load_state_dict(model_dict)
        elif self.weights == 'Seg':
            pretrained_dict = torch.load('weights/nyu_v2/tf_finetune_seg.pth')
            self.load_state_dict(pretrained_dict)
        elif self.weights == 'Normal':
            pretrained_dict = torch.load('weights/nyu_v2/tf_finetune_normal.pth')
            self.load_state_dict(pretrained_dict)
        elif self.weights == '':
            pass
        else:
            raise NotImplementedError


if __name__ == "__main__":
    net = DeepLabLargeFOVBN(3, 10)
    in_ten = torch.randn(1, 3, 321, 321)
    out = net(in_ten)
    print(out.size())
    print(net.stages[1])
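Two porting details in this file are worth spelling out. First, PyTorch's BatchNorm2d updates its running statistics as running = (1 - momentum) * running + momentum * batch_stat, so eps=1e-03, momentum=0.05 appears chosen to match a TF batch norm with epsilon 1e-3 and decay 0.95 (an inference from the two frameworks' conventions, not stated in the commit). Second, for even input sizes, TF 'SAME' pooling with a 3x3 window and stride 2 needs one extra row and column of padding, which TF adds on the bottom/right only; ConstantPad2d((0, 1, 0, 1), 0) reproduces that, whereas PyTorch's symmetric padding=1 would shift the pooling grid. A small check of the padding claim (hypothetical code, not part of the commit):

import torch
import torch.nn as nn

pool = nn.Sequential(nn.ConstantPad2d((0, 1, 0, 1), 0), nn.MaxPool2d(3, stride=2))
x = torch.randn(1, 1, 320, 320)
print(pool(x).shape)  # torch.Size([1, 1, 160, 160]) == ceil(320 / 2), as TF 'SAME' computes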
train.py (6 changes: 3 additions & 3 deletions)

@@ -9,7 +9,7 @@
 from data.loader import MultiTaskDataset

 from models.nddr_net import NDDRNet
-from models.vgg16_lfov import DeepLabLargeFOV
+from models.vgg16_lfov_bn import DeepLabLargeFOVBN

 from utils.losses import get_normal_loss

@@ -80,8 +80,8 @@ def main():
         os.makedirs(experiment_log_dir)
     writer = SummaryWriter(logdir=experiment_log_dir)

-    net1 = DeepLabLargeFOV(3, cfg.MODEL.NET1_CLASSES, weights=cfg.TRAIN.WEIGHT_1)
-    net2 = DeepLabLargeFOV(3, cfg.MODEL.NET2_CLASSES, weights=cfg.TRAIN.WEIGHT_2)
+    net1 = DeepLabLargeFOVBN(3, cfg.MODEL.NET1_CLASSES, weights=cfg.TRAIN.WEIGHT_1)
+    net2 = DeepLabLargeFOVBN(3, cfg.MODEL.NET2_CLASSES, weights=cfg.TRAIN.WEIGHT_2)
     model = NDDRNet(net1, net2,
                     shortcut=cfg.MODEL.SHORTCUT,
                     bn_before_relu=cfg.MODEL.BN_BEFORE_RELU)
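Putting it together, a smoke test of the new BN pipeline from the repository root (hypothetical usage mirroring train.py and eval.py; assumes the repo's Stage and NDDRNet internals shown above):

import torch
from models.nddr_net import NDDRNet
from models.vgg16_lfov_bn import DeepLabLargeFOVBN

net1 = DeepLabLargeFOVBN(3, 40, weights='')  # weights='' skips loading TF weights
net2 = DeepLabLargeFOVBN(3, 3, weights='')
model = NDDRNet(net1, net2, shortcut=False, bn_before_relu=True)
x = torch.randn(1, 3, 321, 321)
seg, normal = model(x)
print(seg.shape, normal.shape)  # both upsampled back to 321 x 321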
