Skip to content

Commit

Permalink
add options for anchors; fix alignment issue with resnet maxpool; mor…
Browse files Browse the repository at this point in the history
…e documentation.
  • Loading branch information
Xinlei Chen committed Mar 30, 2017
1 parent b08edc0 commit 069f163
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 60 deletions.
1 change: 1 addition & 0 deletions experiments/cfgs/res101.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ TRAIN:
BG_THRESH_LO: 0.0
DISPLAY: 20
BATCH_SIZE: 256
WEIGHT_DECAY: 0.0001
SNAPSHOT_PREFIX: res101_faster_rcnn
TEST:
HAS_RPN: True
Expand Down
7 changes: 5 additions & 2 deletions experiments/scripts/test_faster_rcnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ case ${DATASET} in
TRAIN_IMDB="voc_2007_trainval"
TEST_IMDB="voc_2007_test"
ITERS=70000
ANCHORS="[8,16,32]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
ITERS=70000
ANCHORS="[8,16,32]"
;;
coco)
TRAIN_IMDB="coco_2014_train+coco_2014_valminusminival"
TEST_IMDB="coco_2014_minival"
ITERS=490000
ANCHORS="[4,8,16,32]"
;;
*)
echo "No dataset given"
Expand All @@ -55,13 +58,13 @@ if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
--cfg experiments/cfgs/${NET}.yml \
--tag ${EXTRA_ARGS_SLUG} \
--net ${NET} \
--set ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/test_net.py \
--imdb ${TEST_IMDB} \
--model ${NET_FINAL} \
--cfg experiments/cfgs/${NET}.yml \
--net ${NET} \
--set ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} ${EXTRA_ARGS}
fi

7 changes: 5 additions & 2 deletions experiments/scripts/test_vgg16.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ case ${DATASET} in
TRAIN_IMDB="voc_2007_trainval"
TEST_IMDB="voc_2007_test"
ITERS=70000
ANCHORS="[8,16,32]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
ITERS=70000
ANCHORS="[8,16,32]"
;;
coco)
TRAIN_IMDB="coco_2014_train+coco_2014_valminusminival"
TEST_IMDB="coco_2014_minival"
ITERS=490000
ANCHORS="[4,8,16,32]"
;;
*)
echo "No dataset given"
Expand All @@ -54,13 +57,13 @@ if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
--model ${NET_FINAL} \
--cfg experiments/cfgs/vgg16_depre.yml \
--tag ${EXTRA_ARGS_SLUG} \
--set ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/test_vgg16_net.py \
--imdb ${TEST_IMDB} \
--weight data/imagenet_weights/vgg16.weights \
--model ${NET_FINAL} \
--cfg experiments/cfgs/vgg16_depre.yml \
--set ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} ${EXTRA_ARGS}
fi

7 changes: 5 additions & 2 deletions experiments/scripts/train_faster_rcnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ case ${DATASET} in
TEST_IMDB="voc_2007_test"
STEPSIZE=50000
ITERS=70000
ANCHORS="[8,16,32]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
STEPSIZE=50000
ITERS=70000
ANCHORS="[8,16,32]"
;;
coco)
TRAIN_IMDB="coco_2014_train+coco_2014_valminusminival"
TEST_IMDB="coco_2014_minival"
STEPSIZE=350000
ITERS=490000
ANCHORS="[4,8,16,32]"
;;
*)
echo "No dataset given"
Expand Down Expand Up @@ -61,7 +64,7 @@ if [ ! -f ${NET_FINAL}.index ]; then
--cfg experiments/cfgs/${NET}.yml \
--tag ${EXTRA_ARGS_SLUG} \
--net ${NET} \
--set TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
Expand All @@ -70,7 +73,7 @@ if [ ! -f ${NET_FINAL}.index ]; then
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--net ${NET} \
--set TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
fi
fi

Expand Down
7 changes: 5 additions & 2 deletions experiments/scripts/vgg16.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ case ${DATASET} in
TEST_IMDB="voc_2007_test"
STEPSIZE=50000
ITERS=70000
ANCHORS="[8,16,32]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
STEPSIZE=50000
ITERS=70000
ANCHORS="[8,16,32]"
;;
coco)
TRAIN_IMDB="coco_2014_train+coco_2014_valminusminival"
TEST_IMDB="coco_2014_minival"
STEPSIZE=350000
ITERS=490000
ANCHORS="[4,8,16,32]"
;;
*)
echo "No dataset given"
Expand Down Expand Up @@ -59,15 +62,15 @@ if [ ! -f ${NET_FINAL}.index ]; then
--iters ${ITERS} \
--cfg experiments/cfgs/vgg16_depre.yml \
--tag ${EXTRA_ARGS_SLUG} \
--set TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_vgg16_net.py \
--weight data/imagenet_weights/vgg16.weights \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/vgg16_depre.yml \
--set TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
--set ANCHOR_SCALES ${ANCHORS} TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
fi
fi

Expand Down
13 changes: 9 additions & 4 deletions lib/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,14 @@

__C.RESNET = edict()

# If max-pooling is appended after crop_and_resize, if true, the region will be resized
# to a squre of 2xPOOLING_SIZE, then 2x2 max-pooling is applied; otherwise
# the region will be directly resized to a square of POOLING_SIZE
# Option to set if max-pooling is appended after crop_and_resize.
# if true, the region will be resized to a squre of 2xPOOLING_SIZE,
# then 2x2 max-pooling is applied; otherwise the region will be directly
# resized to a square of POOLING_SIZE
__C.RESNET.MAX_POOL = False

# Number of fixed blocks during finetuning, by default the first of all 4 blocks is fixed
# Range: 0 (none) to 3 (all)
__C.RESNET.FIXED_BLOCKS = 1

# Whether to tune the batch nomalization parameters during training
Expand Down Expand Up @@ -253,14 +255,17 @@
__C.USE_GPU_NMS = True

# Default GPU device id
__C.GPU_ID = 0
# __C.GPU_ID = 0

# Default pooling mode, only 'crop' is available
__C.POOLING_MODE = 'crop'

# Size of the pooled region after RoI pooling
__C.POOLING_SIZE = 7

# Anchor scales for RPN
__C.ANCHOR_SCALES = [8,16,32]


def get_output_dir(imdb, weights_filename):
"""Return the directory where experimental artifacts are placed.
Expand Down
14 changes: 5 additions & 9 deletions lib/model/train_val.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen
# Written by Xinlei Chen and Zheqi He
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
Expand Down Expand Up @@ -97,17 +97,12 @@ def train_model(self, sess, max_iters):
self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

# Determine different scales for anchors, see paper
if self.imdb.name.startswith('voc'):
anchors = [8, 16, 32]
else:
anchors = [4, 8, 16, 32]

with sess.graph.as_default():
# Set the random seed for tensorflow
tf.set_random_seed(cfg.RNG_SEED)
# Build the main computation graph
layers = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes,
tag='default', anchor_scales=anchors)
tag='default', anchor_scales=cfg.ANCHOR_SCALES)
# Define the loss
loss = layers['total_loss']
# Set learning rate and momentum
Expand Down Expand Up @@ -156,7 +151,7 @@ def train_model(self, sess, max_iters):
ss_paths = sfiles

if lsf == 0:
# Fresh train directly from VGG weights
# Fresh train directly from ImageNet weights
print('Loading initial model weights from {:s}'.format(self.pretrained_model))
variables = tf.global_variables()

Expand All @@ -172,7 +167,8 @@ def train_model(self, sess, max_iters):
var_to_dic[v.name] = v
continue
if v.name.split(':')[0] in var_keep_dic:
variables_to_restore.append(v)
print('Varibles restored: %s' % v.name)
variables_to_restore.append(v)

restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, self.pretrained_model)
Expand Down
8 changes: 1 addition & 7 deletions lib/model/train_val_vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,13 @@ def train_model(self, sess, max_iters):
self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

# Determine different scales for anchors, see paper
if self.imdb.name.startswith('voc'):
anchors = [8, 16, 32]
else:
anchors = [4, 8, 16, 32]

with sess.graph.as_default():
# Set the random seed for tensorflow
tf.set_random_seed(cfg.RNG_SEED)
# Build the main computation graph
layers = self.net.create_architecture(sess, "TRAIN", self.imdb.num_classes,
caffe_weight_path=self.pretrained_model,
tag='default', anchor_scales=anchors)
tag='default', anchor_scales=cfg.ANCHOR_SCALES)
# Define the loss
loss = layers['total_loss']

Expand Down
3 changes: 2 additions & 1 deletion lib/nets/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def _crop_pool_layer(self, bottom, rois, name):
y1 = tf.slice(rois, [0, 2], [-1, 1], name="y1") / height
x2 = tf.slice(rois, [0, 3], [-1, 1], name="x2") / width
y2 = tf.slice(rois, [0, 4], [-1, 1], name="y2") / height
bboxes = tf.concat([y1, x1, y2, x2], axis=1)
# Won't be backpropagated to rois anyway, but to save time
bboxes = tf.stop_gradient(tf.concat([y1, x1, y2, x2], axis=1))
pre_pool_size = cfg.POOLING_SIZE * 2
crops = tf.image.crop_and_resize(bottom, bboxes, tf.to_int32(batch_ids), [pre_pool_size, pre_pool_size], name="crops")

Expand Down
59 changes: 39 additions & 20 deletions lib/nets/res101.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
from model.config import cfg

def resnet_arg_scope(is_training=True,
weight_decay=0.0001,
weight_decay=cfg.TRAIN.WEIGHT_DECAY,
batch_norm_decay=0.997,
batch_norm_epsilon=1e-5,
batch_norm_scale=True):
batch_norm_params = {
# NOTE 'is_training' here does not work because inside resnet it gets reset:
# https://github.com/tensorflow/models/blob/master/slim/nets/resnet_v1.py#L187
'is_training': False,
'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon,
Expand All @@ -55,15 +57,8 @@ def resnet_arg_scope(is_training=True,
activation_fn=nn_ops.relu,
normalizer_fn=layers.batch_norm,
normalizer_params=batch_norm_params):
with arg_scope([layers.batch_norm], **batch_norm_params):
# The following implies padding='SAME' for pool1, which makes feature
# alignment easier for dense prediction tasks. This is also used in
# https://github.com/facebook/fb.resnet.torch. However the accompanying
# code of 'Deep Residual Learning for Image Recognition' uses
# padding='VALID' for pool1. You can switch to that choice by setting
# tf.contrib.framework.arg_scope([tf.contrib.layers.max_pool2d], padding='VALID').
with arg_scope([layers.max_pool2d], padding='SAME') as arg_sc:
return arg_sc
with arg_scope([layers.batch_norm], **batch_norm_params) as arg_sc:
return arg_sc

class Resnet101(Network):
def __init__(self, batch_size=1):
Expand All @@ -81,7 +76,8 @@ def _crop_pool_layer(self, bottom, rois, name):
y1 = tf.slice(rois, [0, 2], [-1, 1], name="y1") / height
x2 = tf.slice(rois, [0, 3], [-1, 1], name="x2") / width
y2 = tf.slice(rois, [0, 4], [-1, 1], name="y2") / height
bboxes = tf.concat([y1, x1, y2, x2], 1)
# Won't be backpropagated to rois anyway, but to save time
bboxes = tf.stop_gradient(tf.concat([y1, x1, y2, x2], 1))
if cfg.RESNET.MAX_POOL:
pre_pool_size = cfg.POOLING_SIZE * 2
crops = tf.image.crop_and_resize(bottom, bboxes, tf.to_int32(batch_ids), [pre_pool_size, pre_pool_size], name="crops")
Expand All @@ -91,6 +87,16 @@ def _crop_pool_layer(self, bottom, rois, name):

return crops

# Do the first few layers manually, because 'SAME' padding can behave inconsistently
# for images of different sizes: sometimes 0, sometimes 1
def build_base(self):
with tf.variable_scope('resnet_v1_101', 'resnet_v1_101'):
net = resnet_utils.conv2d_same(self._image, 64, 7, stride=2, scope='conv1')
net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]])
net = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', scope='pool1')

return net

def build_network(self, sess, is_training=True):
# select initializers
if cfg.TRAIN.TRUNCATED:
Expand All @@ -105,29 +111,41 @@ def build_network(self, sess, is_training=True):
[(256, 64, 1)] * 2 + [(256, 64, 2)]),
resnet_utils.Block('block2', bottleneck,
[(512, 128, 1)] * 3 + [(512, 128, 2)]),
# Use stride-1 for the last conv4 layer
resnet_utils.Block('block3', bottleneck,
[(1024, 256, 1)] * 22 + [(1024, 256, 1)]),
resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3)
]
if cfg.RESNET.FIXED_BLOCKS > 0:
assert(cfg.RESNET.FIXED_BLOCKS < 4 and cfg.RESNET.FIXED_BLOCKS >= 0)
if cfg.RESNET.FIXED_BLOCKS == 3:
with slim.arg_scope(resnet_arg_scope(is_training=False)):
net, _ = resnet_v1.resnet_v1(self._image,
blocks[0:cfg.RESNET.FIXED_BLOCKS],
global_pool=False,
include_root_block=True,
scope='resnet_v1_101')
net = self.build_base()
net_conv5, _ = resnet_v1.resnet_v1(net,
blocks[0:cfg.RESNET.FIXED_BLOCKS],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
elif cfg.RESNET.FIXED_BLOCKS > 0:
with slim.arg_scope(resnet_arg_scope(is_training=False)):
net = self.build_base()
net, _ = resnet_v1.resnet_v1(net,
blocks[0:cfg.RESNET.FIXED_BLOCKS],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
with slim.arg_scope(resnet_arg_scope(is_training=is_training)):
net_conv5, _ = resnet_v1.resnet_v1(net,
blocks[cfg.RESNET.FIXED_BLOCKS:-1],
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')
else:
else: # cfg.RESNET.FIXED_BLOCKS == 0
with slim.arg_scope(resnet_arg_scope(is_training=is_training)):
net_conv5, _ = resnet_v1.resnet_v1(self._image,
net = self.build_base()
net_conv5, _ = resnet_v1.resnet_v1(net,
blocks[0:-1],
global_pool=False,
include_root_block=True,
include_root_block=False,
scope='resnet_v1_101')

self._act_summaries.append(net_conv5)
Expand Down Expand Up @@ -177,6 +195,7 @@ def build_network(self, sess, is_training=True):
global_pool=False,
include_root_block=False,
scope='resnet_v1_101')

with tf.variable_scope('resnet_v1_101', 'resnet_v1_101',
regularizer=tf.contrib.layers.l2_regularizer(cfg.TRAIN.WEIGHT_DECAY)):
# Average pooling done by reduce_mean
Expand Down
Loading

0 comments on commit 069f163

Please sign in to comment.