Skip to content

Commit

Permalink
add resnet v1 models, modify path settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
endernewton committed Apr 30, 2017
1 parent ed9447b commit 0fdb119
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 47 deletions.
1 change: 1 addition & 0 deletions data/scripts/fetch_coco_long_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../" && pwd )"
cd $DIR

FILE=coco_long.tgz
# replace it with gs11655.sp.cs.cmu.edu if ladoga.graphics.cs.cmu.edu does not work
URL=http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/$FILE
CHECKSUM=099f637235d24ec97d9708b3ff66bd7f

Expand Down
1 change: 1 addition & 0 deletions data/scripts/fetch_faster_rcnn_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../" && pwd )"
cd $DIR

FILE=faster_rcnn_models.tgz
# replace it with gs11655.sp.cs.cmu.edu if ladoga.graphics.cs.cmu.edu does not work
URL=http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/$FILE
CHECKSUM=865cdf7350a87ef41d6476e6e33b7212

Expand Down
1 change: 1 addition & 0 deletions data/scripts/fetch_imagenet_weights.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../" && pwd )"
cd $DIR

FILE=imagenet_weights.tgz
# replace it with gs11655.sp.cs.cmu.edu if ladoga.graphics.cs.cmu.edu does not work
URL=http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/$FILE
CHECKSUM=e9772d7c761040f10a67d389336b90ce

Expand Down
17 changes: 17 additions & 0 deletions experiments/cfgs/res50.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
EXP_DIR: res50
TRAIN:
HAS_RPN: True
IMS_PER_BATCH: 1
BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
RPN_POSITIVE_OVERLAP: 0.7
RPN_BATCHSIZE: 256
PROPOSAL_METHOD: gt
BG_THRESH_LO: 0.0
DISPLAY: 20
BATCH_SIZE: 256
WEIGHT_DECAY: 0.0001
DOUBLE_BIAS: False
SNAPSHOT_PREFIX: res101_faster_rcnn
TEST:
HAS_RPN: True
POOLING_MODE: crop
15 changes: 9 additions & 6 deletions lib/model/train_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ def train_model(self, sess, max_iters):
var_to_dic[v.name] = v
continue
# exclude the first conv layer to swap RGB to BGR
if v.name == 'vgg_16/conv1/conv1_1/weights:0' or v.name == 'resnet_v1_101/conv1/weights:0':
if v.name == 'vgg_16/conv1/conv1_1/weights:0'
or v.name == 'resnet_v1_50/conv1/weights:0'
or v.name == 'resnet_v1_101/conv1/weights:0'
or v.name == 'resnet_v1_152/conv1/weights:0':
var_to_dic[v.name] = v
continue
if v.name.split(':')[0] in var_keep_dic:
Expand Down Expand Up @@ -200,16 +203,16 @@ def train_model(self, sess, max_iters):
var_to_dic['vgg_16/fc7/weights:0'].get_shape())))
sess.run(tf.assign(var_to_dic['vgg_16/conv1/conv1_1/weights:0'],
tf.reverse(conv1_rgb, [2])))
elif self.net._arch == 'res101':
print('Fix Resnet101 layers..')
with tf.variable_scope('Fix_Res101') as scope:
elif self.net._arch.startswith('res_v1'):
print('Fix Resnet V1 layers..')
with tf.variable_scope('Fix_Resnet_V1') as scope:
with tf.device("/cpu:0"):
# fix RGB to BGR
conv1_rgb = tf.get_variable("conv1_rgb", [7, 7, 3, 64], trainable=False)
restorer_fc = tf.train.Saver({"resnet_v1_101/conv1/weights": conv1_rgb})
restorer_fc = tf.train.Saver({self.net._resnet_scope + "/conv1/weights": conv1_rgb})
restorer_fc.restore(sess, self.pretrained_model)

sess.run(tf.assign(var_to_dic['resnet_v1_101/conv1/weights:0'], tf.reverse(conv1_rgb, [2])))
sess.run(tf.assign(var_to_dic[self.net._resnet_scope + '/conv1/weights:0'], tf.reverse(conv1_rgb, [2])))
else:
# every network should fix the rgb issue at least
raise NotImplementedError
Expand Down
5 changes: 3 additions & 2 deletions lib/nets/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,11 @@ def create_architecture(self, sess, mode, num_classes, tag=None,

return layers_to_output

# Extract the head feature maps, for example for vgg16 it is conv5_3
# only useful during testing mode
def extract_conv5(self, sess, image):
def extract_head(self, sess, image):
feed_dict = {self._image: image}
feat = sess.run(self._layers["conv5_3"], feed_dict=feed_dict)
feat = sess.run(self._layers["head"], feed_dict=feed_dict)
return feat

# only useful during testing mode
Expand Down
86 changes: 58 additions & 28 deletions lib/nets/res101.py → lib/nets/resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def resnet_arg_scope(is_training=True,
with arg_scope([layers.batch_norm], **batch_norm_params) as arg_sc:
return arg_sc

class Resnet101(Network):
def __init__(self, batch_size=1):
class resnetv1(Network):
def __init__(self, batch_size=1, num_layers=50):
Network.__init__(self, batch_size=batch_size)
self._arch = 'res101'
self._num_layers = num_layers
self._arch = 'res_v1_%d' % num_layers
self._resnet_scope = 'resnet_v1_%d' % num_layers

def _crop_pool_layer(self, bottom, rois, name):
with tf.variable_scope(name) as scope:
Expand Down Expand Up @@ -81,7 +83,7 @@ def _crop_pool_layer(self, bottom, rois, name):
# Do the first few layers manually, because 'SAME' padding can behave inconsistently
# for images of different sizes: sometimes 0, sometimes 1
def build_base(self):
with tf.variable_scope('resnet_v1_101', 'resnet_v1_101'):
with tf.variable_scope(self._resnet_scope, self._resnet_scope):
net = resnet_utils.conv2d_same(self._image, 64, 7, stride=2, scope='conv1')
net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]])
net = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', scope='pool1')
Expand All @@ -97,57 +99,85 @@ def build_network(self, sess, is_training=True):
initializer = tf.random_normal_initializer(mean=0.0, stddev=0.01)
initializer_bbox = tf.random_normal_initializer(mean=0.0, stddev=0.001)
bottleneck = resnet_v1.bottleneck
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
# Use stride-1 for the last conv4 layer
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 22 + [(1024, 256, 1)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
# choose different blocks for different number of layers
if self._num_layers == 50:
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
# Use stride-1 for the last conv4 layer
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 5 + [(1024, 256, 1)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
elif self._num_layers == 101:
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
# Use stride-1 for the last conv4 layer
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 22 + [(1024, 256, 1)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
elif self._num_layers == 152:
blocks = [
resnet_utils.Block('block1', bottleneck,
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 7 + [(512, 128, 2)]),
# Use stride-1 for the last conv4 layer
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 35 + [(1024, 256, 1)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
else:
# other numbers are not supported
raise NotImplementedError

assert (0 <= cfg.RESNET.FIXED_BLOCKS < 4)
if cfg.RESNET.FIXED_BLOCKS == 3:
with slim.arg_scope(resnet_arg_scope(is_training=False)):
net = self.build_base()
net_conv5, _ = resnet_v1.resnet_v1(net,
net_conv4, _ = resnet_v1.resnet_v1(net,
blocks[0:cfg.RESNET.FIXED_BLOCKS],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
scope=self._resnet_scope)
elif cfg.RESNET.FIXED_BLOCKS > 0:
with slim.arg_scope(resnet_arg_scope(is_training=False)):
net = self.build_base()
net, _ = resnet_v1.resnet_v1(net,
blocks[0:cfg.RESNET.FIXED_BLOCKS],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
scope=self._resnet_scope)

with slim.arg_scope(resnet_arg_scope(is_training=is_training)):
net_conv5, _ = resnet_v1.resnet_v1(net,
net_conv4, _ = resnet_v1.resnet_v1(net,
blocks[cfg.RESNET.FIXED_BLOCKS:-1],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
scope=self._resnet_scope)
else: # cfg.RESNET.FIXED_BLOCKS == 0
with slim.arg_scope(resnet_arg_scope(is_training=is_training)):
net = self.build_base()
net_conv5, _ = resnet_v1.resnet_v1(net,
net_conv4, _ = resnet_v1.resnet_v1(net,
blocks[0:-1],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
scope=self._resnet_scope)

self._act_summaries.append(net_conv5)
self._layers['conv5_3'] = net_conv5
with tf.variable_scope('resnet_v1_101', 'resnet_v1_101'):
self._act_summaries.append(net_conv4)
self._layers['head'] = net_conv4
with tf.variable_scope(self._resnet_scope, self._resnet_scope):
# build the anchors for the image
self._anchor_component()

# rpn
rpn = slim.conv2d(net_conv5, 512, [3, 3], trainable=is_training, weights_initializer=initializer,
rpn = slim.conv2d(net_conv4, 512, [3, 3], trainable=is_training, weights_initializer=initializer,
scope="rpn_conv/3x3")
self._act_summaries.append(rpn)
rpn_cls_score = slim.conv2d(rpn, self._num_anchors * 2, [1, 1], trainable=is_training,
Expand Down Expand Up @@ -176,7 +206,7 @@ def build_network(self, sess, is_training=True):

# rcnn
if cfg.POOLING_MODE == 'crop':
pool5 = self._crop_pool_layer(net_conv5, rois, "pool5")
pool5 = self._crop_pool_layer(net_conv4, rois, "pool5")
else:
raise NotImplementedError

Expand All @@ -185,9 +215,9 @@ def build_network(self, sess, is_training=True):
blocks[-1:],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
scope=self._resnet_scope)

with tf.variable_scope('resnet_v1_101', 'resnet_v1_101'):
with tf.variable_scope(self._resnet_scope, self._resnet_scope):
# Average pooling done by reduce_mean
fc7 = tf.reduce_mean(fc7, axis=[1, 2])
cls_score = slim.fully_connected(fc7, self._num_classes, weights_initializer=initializer,
Expand Down
2 changes: 1 addition & 1 deletion lib/nets/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def build_network(self, sess, is_training=True):
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3],
trainable=is_training, scope='conv5')
self._act_summaries.append(net)
self._layers['conv5_3'] = net
self._layers['head'] = net
# build the anchors for the image
self._anchor_component()

Expand Down
3 changes: 2 additions & 1 deletion tools/_init_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ def add_path(path):
lib_path = osp.join(this_dir, '..', 'lib')
add_path(lib_path)

add_path('data/coco/PythonAPI')
coco_path = osp.join(this_dir, '..', 'data', 'coco', 'PythonAPI')
add_path(coco_path)
3 changes: 2 additions & 1 deletion tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import argparse

from nets.vgg16 import vgg16
from nets.res101 import Resnet101
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
Expand All @@ -39,6 +39,7 @@

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(im, class_name, dets, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
Expand Down
12 changes: 8 additions & 4 deletions tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import tensorflow as tf
from nets.vgg16 import vgg16
from nets.res101 import Resnet101
from nets.resnet_v1 import resnetv1

def parse_args():
"""
Expand All @@ -41,8 +41,8 @@ def parse_args():
help='tag of the model',
default='', type=str)
parser.add_argument('--net', dest='net',
help='vgg16 or res101',
default='res101', type=str)
help='vgg16, res50, res101, res152',
default='res50', type=str)
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER)
Expand Down Expand Up @@ -90,8 +90,12 @@ def parse_args():
# load network
if args.net == 'vgg16':
net = vgg16(batch_size=1)
elif args.net == 'res50':
net = resnetv1(batch_size=1, num_layers=50)
elif args.net == 'res101':
net = Resnet101(batch_size=1)
net = resnetv1(batch_size=1, num_layers=101)
elif args.net == 'res152':
net = resnetv1(batch_size=1, num_layers=152)
else:
raise NotImplementedError

Expand Down
13 changes: 9 additions & 4 deletions tools/trainval_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import tensorflow as tf
from nets.vgg16 import vgg16
from nets.res101 import Resnet101
from nets.resnet_v1 import resnetv1

def parse_args():
"""
Expand All @@ -45,8 +45,8 @@ def parse_args():
help='tag of the model',
default=None, type=str)
parser.add_argument('--net', dest='net',
help='vgg16 or res101',
default='res101', type=str)
help='vgg16, res50, res101, res152',
default='res50', type=str)
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER)
Expand Down Expand Up @@ -119,10 +119,15 @@ def get_roidb(imdb_name):
print('{:d} validation roidb entries'.format(len(valroidb)))
cfg.TRAIN.USE_FLIPPED = orgflip

# load network
if args.net == 'vgg16':
net = vgg16(batch_size=cfg.TRAIN.IMS_PER_BATCH)
elif args.net == 'res50':
net = resnetv1(batch_size=cfg.TRAIN.IMS_PER_BATCH, num_layers=50)
elif args.net == 'res101':
net = Resnet101(batch_size=cfg.TRAIN.IMS_PER_BATCH)
net = resnetv1(batch_size=cfg.TRAIN.IMS_PER_BATCH, num_layers=101)
elif args.net == 'res152':
net = resnetv1(batch_size=cfg.TRAIN.IMS_PER_BATCH, num_layers=152)
else:
raise NotImplementedError

Expand Down

0 comments on commit 0fdb119

Please sign in to comment.