Skip to content

Commit

Permalink
fix link.
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinlei Chen committed Sep 14, 2017
1 parent 7c058de commit d726834
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 34 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Approximate *baseline* [setup](https://github.com/endernewton/tf-faster-rcnn/blo
- For COCO, we find the performance improving with more iterations (VGG16 350k/490k: 26.9, 600k/790k: 28.3, 900k/1190k: 29.5), and potentially better performance can be achieved with even more iterations.
- For Resnets, we fix the first block (total 4) when fine-tuning the network, and only use ``crop_and_resize`` to resize the RoIs (7x7) without max-pool (which I find useless especially for COCO). The final feature maps are average-pooled for classification and regression. All batch normalization parameters are fixed. Weight decay is set to Renset101 default 1e-4. Learning rate for biases is not doubled.
- For approximate [FPN](https://arxiv.org/abs/1612.03144) baseline setup we simply resize the image with 800 pixels, add 32^2 anchors, and take 1000 proposals during testing.
- Check out [here](http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/)/[here](http://gs11655.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/)/[here](https://drive.google.com/open?id=0B1_fAEgxdnvJSmF3YUlZcHFqWTQ) for the latest models, including longer COCO VGG16 models and Resnet ones.
- Check out [here](http://ladoga.graphics.cs.cmu.edu/xinleic/tf-faster-rcnn/)/[here](http://xinlei.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/)/[here](https://drive.google.com/open?id=0B1_fAEgxdnvJSmF3YUlZcHFqWTQ) for the latest models, including longer COCO VGG16 models and Resnet ones.

### Additional features
Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.02138.pdf) are added to make research life easier:
Expand Down Expand Up @@ -99,7 +99,7 @@ If you find it useful, the ``data/cache`` folder created on my side is also shar
./data/scripts/fetch_faster_rcnn_models.sh
```
**Note**: if you cannot download the models through the link, or you want to try more models, you can check out the following solutions and optionally update the downloading script:
- Another server [here](http://gs11655.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/).
- Another server [here](http://xinlei.sp.cs.cmu.edu/xinleic/tf-faster-rcnn/).
- Google drive [here](https://drive.google.com/open?id=0B1_fAEgxdnvJSmF3YUlZcHFqWTQ).

2. Create a folder and a softlink to use the pre-trained model
Expand Down
7 changes: 7 additions & 0 deletions experiments/scripts/test_faster_rcnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ case ${DATASET} in
ANCHORS="[8,16,32]"
RATIOS="[0.5,1,2]"
;;
pascal_voc_diff)
TRAIN_IMDB="voc_2007_trainval"
TEST_IMDB="voc_2007_test_diff"
ITERS=70000
ANCHORS="[8,16,32]"
RATIOS="[0.5,1,2]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
Expand Down
56 changes: 32 additions & 24 deletions experiments/scripts/train_faster_rcnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ case ${DATASET} in
ANCHORS="[8,16,32]"
RATIOS="[0.5,1,2]"
;;
pascal_voc_diff)
TRAIN_IMDB="voc_2007_trainval"
TEST_IMDB="voc_2007_test_diff"
STEPSIZE="[50000]"
ITERS=70000
ANCHORS="[8,16,32]"
RATIOS="[0.5,1,2]"
;;
pascal_voc_0712)
TRAIN_IMDB="voc_2007_trainval+voc_2012_trainval"
TEST_IMDB="voc_2007_test"
Expand Down Expand Up @@ -51,35 +59,35 @@ echo Logging output to "$LOG"

set +x
if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
NET_FINAL=output/${NET}/${TRAIN_IMDB}/${EXTRA_ARGS_SLUG}/${NET}_faster_rcnn_iter_${ITERS}.ckpt
NET_FINAL=output/${NET}/${TRAIN_IMDB}/${EXTRA_ARGS_SLUG}/${NET}_faster_rcnn_iter_${ITERS}.ckpt
else
NET_FINAL=output/${NET}/${TRAIN_IMDB}/default/${NET}_faster_rcnn_iter_${ITERS}.ckpt
NET_FINAL=output/${NET}/${TRAIN_IMDB}/default/${NET}_faster_rcnn_iter_${ITERS}.ckpt
fi
set -x

if [ ! -f ${NET_FINAL}.index ]; then
if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--tag ${EXTRA_ARGS_SLUG} \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
fi
if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--tag ${EXTRA_ARGS_SLUG} \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
fi
fi

./experiments/scripts/test_faster_rcnn.sh $@
5 changes: 5 additions & 0 deletions lib/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: pascal_voc(split, year))

for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}_diff'.format(year, split)
__sets[name] = (lambda split=split, year=year: pascal_voc(split, year, use_diff=True))

# Set up coco_2014_<split>
for year in ['2014']:
for split in ['train', 'val', 'minival', 'valminusminival', 'trainval']:
Expand Down
14 changes: 8 additions & 6 deletions lib/datasets/pascal_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@


class pascal_voc(imdb):
def __init__(self, image_set, year, devkit_path=None):
imdb.__init__(self, 'voc_' + year + '_' + image_set)
def __init__(self, image_set, year, use_diff=False):
name = 'voc_' + year + '_' + image_set
if use_diff:
name += '_diff'
imdb.__init__(self, name)
self._year = year
self._image_set = image_set
self._devkit_path = self._get_default_path() if devkit_path is None \
else devkit_path
self._devkit_path = self._get_default_path()
self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
self._classes = ('__background__', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
Expand All @@ -48,7 +50,7 @@ def __init__(self, image_set, year, devkit_path=None):
# PASCAL specific config options
self.config = {'cleanup': True,
'use_salt': True,
'use_diff': False,
'use_diff': use_diff,
'matlab_eval': False,
'rpn_file': None}

Expand Down Expand Up @@ -241,7 +243,7 @@ def _do_python_eval(self, output_dir='output'):
filename = self._get_voc_results_file_template().format(cls)
rec, prec, ap = voc_eval(
filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
use_07_metric=use_07_metric)
use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
aps += [ap]
print(('AP for {} = {:.4f}'.format(cls, ap)))
with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
Expand Down
8 changes: 6 additions & 2 deletions lib/datasets/voc_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def voc_eval(detpath,
classname,
cachedir,
ovthresh=0.5,
use_07_metric=False):
use_07_metric=False,
use_diff=False):
"""rec, prec, ap = voc_eval(detpath,
annopath,
imagesetfile,
Expand Down Expand Up @@ -133,7 +134,10 @@ def voc_eval(detpath,
for imagename in imagenames:
R = [obj for obj in recs[imagename] if obj['name'] == classname]
bbox = np.array([x['bbox'] for x in R])
difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
if use_diff:
difficult = np.array([False for x in R]).astype(np.bool)
else:
difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
det = [False] * len(R)
npos = npos + sum(~difficult)
class_recs[imagename] = {'bbox': bbox,
Expand Down

0 comments on commit d726834

Please sign in to comment.