v0.4.3 (zhanghang1989#71)
- ADE20K training model
- Amazon legal approval

fixes zhanghang1989#69
zhanghang1989 committed Jun 15, 2018
1 parent 9bc7053 commit 32e382b
Showing 26 changed files with 610 additions and 114 deletions.
15 changes: 12 additions & 3 deletions LICENSE
@@ -1,6 +1,7 @@
 MIT License

-Copyright (c) 2017 Hang Zhang
+Copyright (c) 2017 Hang Zhang. All rights reserved.
+Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All rights reserved.

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -9,8 +10,16 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+3. Neither the name of Amazon Inc nor the names of the contributors may be
+used to endorse or promote products derived from this software without
+specific prior written permission.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
39 changes: 29 additions & 10 deletions docs/source/experiments/segmentation.rst
@@ -27,22 +27,27 @@ Test Pre-trained Model
 for example ``Encnet_ResNet50_PContext``::

     python test.py --dataset PContext --model-zoo Encnet_ResNet50_PContext --eval
-    # pixAcc: 0.7862, mIoU: 0.4946: 100%|████████████████████████| 319/319 [09:44<00:00, 1.83s/it]
+    # pixAcc: 0.7838, mIoU: 0.4958: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]

 The command for training the model can be found by clicking ``cmd`` in the table.

 .. role:: raw-html(raw)
     :format: html

-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Model | pixAcc | mIoU | Command |
-+==================================+===========+===========+==============================================================================================+
-| FCN_ResNet50_PContext | 76.0% | 45.7 | :raw-html:`<a href="javascript:toggleblock('cmd_fcn50_pcont')" class="toggleblock">cmd</a>` |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Encnet_ResNet50_PContext | 78.6% | 49.5 | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
-| Encnet_ResNet101_PContext | 80.0% | 52.1 | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` |
-+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| Model | pixAcc | mIoU | Note | Command | Logs |
++==================================+===========+===========+===========+==============================================================================================+============+
+| Encnet_ResNet50_PContext | 78.4% | 49.6% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` | ENC50PC_ |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| EncNet_ResNet101_PContext | 79.9% | 51.8% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_ |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+| EncNet_ResNet50_ADE | 79.8% | 41.3% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>` | ENC50ADE_ |
++----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+
+.. _ENC50PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_pcontext.log?raw=true
+.. _ENC101PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_pcontext.log?raw=true
+.. _ENC50ADE: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_ade.log?raw=true
+

 .. raw:: html

@@ -58,6 +63,14 @@ Test Pre-trained Model
     CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset PContext --model EncNet --aux --se-loss --backbone resnet101
     </code>

+    <code xml:space="preserve" id="cmd_psp50_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model PSP --aux
+    </code>
+
+    <code xml:space="preserve" id="cmd_enc50_ade" style="display: none; text-align: left; white-space: pre-wrap">
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
+    </code>
+
 Quick Demo
 ~~~~~~~~~~

@@ -105,6 +118,12 @@ Train Your Own Model

 - Detail training options, please run ``python train.py -h``.

+- The validation metrics during the training only using center-crop is just for monitoring the
+  training correctness purpose. For evaluating the pretrained model on validation set using MS,
+  please use the command::
+
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval
+
 Citation
 --------

5 changes: 4 additions & 1 deletion encoding/datasets/ade20k.py
@@ -75,9 +75,12 @@ def _get_ade20k_pairs(folder, split='train'):
     if split == 'train':
         img_folder = os.path.join(folder, 'images/training')
         mask_folder = os.path.join(folder, 'annotations/training')
-    else:
+    elif split == 'val':
         img_folder = os.path.join(folder, 'images/validation')
         mask_folder = os.path.join(folder, 'annotations/validation')
+    else:
+        img_folder = os.path.join(folder, 'images/trainval')
+        mask_folder = os.path.join(folder, 'annotations/trainval')
     for filename in os.listdir(img_folder):
         basename, _ = os.path.splitext(filename)
         if filename.endswith(".jpg"):
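
Note on the `ade20k.py` hunk: with the new `else` branch, any split name other than `'train'` or `'val'` resolves to the `trainval` folders. The sketch below shows the same lookup written as an explicit table. `SPLIT_DIRS` and `ade20k_folders` are hypothetical names introduced only for illustration, and raising on an unknown split is an assumption of this sketch, not something the commit does.

```python
import os

# Hypothetical table: split name -> sub-folder name used in the hunk.
SPLIT_DIRS = {
    'train': 'training',
    'val': 'validation',
    'trainval': 'trainval',
}


def ade20k_folders(folder, split='train'):
    """Return (img_folder, mask_folder) for a split; reject unknown names."""
    try:
        sub = SPLIT_DIRS[split]
    except KeyError:
        raise ValueError('unknown split: {!r}'.format(split))
    img_folder = os.path.join(folder, 'images', sub)
    mask_folder = os.path.join(folder, 'annotations', sub)
    return img_folder, mask_folder
```

Failing loudly on a misspelled split is one possible design choice; the committed code instead treats everything else as `trainval`.
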
126 changes: 126 additions & 0 deletions encoding/datasets/coco.py
@@ -0,0 +1,126 @@
+import os
+from tqdm import trange
+from PIL import Image, ImageOps, ImageFilter
+import numpy as np
+import torch
+
+from .base import BaseDataset
+
+"""
+NUM_CHANNEL = 91
+[] background
+[5] airplane
+[2] bicycle
+[16] bird
+[9] boat
+[44] bottle
+[6] bus
+[3] car
+[17] cat
+[62] chair
+[21] cow
+[67] dining table
+[18] dog
+[19] horse
+[4] motorcycle
+[1] person
+[64] potted plant
+[20] sheep
+[63] couch
+[7] train
+[72] tv
+"""
+CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
+    1, 64, 20, 63, 7, 72]
+
+
+class COCOSegmentation(BaseDataset):
+    def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
+                 mode=None, transform=None, target_transform=None):
+        super(COCOSegmentation, self).__init__(
+            root, split, mode, transform, target_transform)
+        from pycocotools.coco import COCO
+        from pycocotools import mask
+        if mode == 'train':
+            print('train set')
+            ann_file = os.path.join(root, 'coco/annotations/instances_train2014.json')
+            ids_file = os.path.join(root, 'coco/annotations/train_ids.pth')
+            root = os.path.join(root, 'coco/train2014')
+        else:
+            print('val set')
+            ann_file = os.path.join(root, 'coco/annotations/instances_val2014.json')
+            ids_file = os.path.join(root, 'coco/annotations/val_ids.pth')
+            root = os.path.join(root, 'coco/val2014')
+        self.coco = COCO(ann_file)
+        self.coco_mask = mask
+        if os.path.exists(ids_file):
+            self.ids = torch.load(ids_file)
+        else:
+            ids = list(self.coco.imgs.keys())
+            self.ids = self._preprocess(ids, ids_file)
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __getitem__(self, index):
+        coco = self.coco
+        img_id = self.ids[index]
+        img_metadata = coco.loadImgs(img_id)[0]
+        path = img_metadata['file_name']
+        img = Image.open(os.path.join(self.root, path)).convert('RGB')
+        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
+        mask = Image.fromarray(self._gen_seg_mask(cocotarget, img_metadata['height'],
+                                                  img_metadata['width']))
+        # synchrosized transform
+        if self.mode == 'train':
+            img, mask = self._sync_transform(img, mask)
+        elif self.mode == 'val':
+            img, mask = self._val_sync_transform(img, mask)
+        else:
+            assert self.mode == 'testval'
+            mask = self._mask_transform(mask)
+        # general resize, normalize and toTensor
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            mask = self.target_transform(mask)
+        return img, mask
+
+    def __len__(self):
+        return len(self.ids)
+
+    def _gen_seg_mask(self, target, h, w):
+        mask = np.zeros((h, w), dtype=np.uint8)
+        coco_mask = self.coco_mask
+        for instance in target:
+            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
+            m = coco_mask.decode(rle)
+            cat = instance['category_id']
+            if cat in CAT_LIST:
+                c = CAT_LIST.index(cat)
+            else:
+                continue
+            if len(m.shape) < 3:
+                mask[:, :] += (mask == 0) * (m * c)
+            else:
+                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
+        return mask
+
+    def _preprocess(self, ids, ids_file):
+        print("Preprocessing mask, this will take a while." + \
+            "But don't worry, it only run once for each split.")
+        tbar = trange(len(ids))
+        new_ids = []
+        for i in tbar:
+            img_id = ids[i]
+            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
+            img_metadata = self.coco.loadImgs(img_id)[0]
+            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
+                                      img_metadata['width'])
+            # more than 1k pixels
+            if (mask > 0).sum() > 1000:
+                new_ids.append(img_id)
+            tbar.set_description('Doing: {}/{}, got {} qualified images'.\
+                format(i, len(ids), len(new_ids)))
+        print('Found number of qualified images: ', len(new_ids))
+        torch.save(new_ids, ids_file)
+        return new_ids
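
Note on `_gen_seg_mask` in `coco.py`: each instance's COCO category id is remapped to its position in `CAT_LIST`, and the expression `mask += (mask == 0) * (m * c)` only writes to pixels that are still background, so the first instance to claim a pixel keeps it. The following is a minimal pure-Python sketch of that rule, with nested lists standing in for the NumPy arrays. `merge_instances` and its `(category_id, binary_mask)` input format are hypothetical and exist only for this illustration; the real method decodes masks with `pycocotools`.

```python
# The 21 category ids kept by coco.py, in the order listed in its docstring.
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
            1, 64, 20, 63, 7, 72]


def merge_instances(instances, h, w):
    """Paint instances onto an h x w label map; the first writer wins.

    `instances` is a hypothetical list of (category_id, binary_mask) pairs,
    where binary_mask is an h x w nested list of 0/1 values.
    """
    label = [[0] * w for _ in range(h)]
    for cat, m in instances:
        if cat not in CAT_LIST:
            continue  # categories outside the list stay background
        c = CAT_LIST.index(cat)
        for y in range(h):
            for x in range(w):
                # Same effect as `mask += (mask == 0) * (m * c)`.
                if label[y][x] == 0 and m[y][x]:
                    label[y][x] = c
    return label


person = [[1, 1, 0], [1, 1, 0]]   # COCO id 1  -> index 15
car = [[0, 1, 1], [0, 1, 1]]      # COCO id 3  -> index 7
zebra = [[1, 1, 1], [1, 1, 1]]    # COCO id 24 is not in CAT_LIST
label = merge_instances([(1, person), (3, car), (24, zebra)], 2, 3)
print(label)  # [[15, 15, 7], [15, 15, 7]]
```

The middle column is covered by both the person and the car; it keeps the person label because that instance was painted first.
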
41 changes: 26 additions & 15 deletions encoding/datasets/pcontext.py
@@ -6,10 +6,10 @@

 from PIL import Image, ImageOps, ImageFilter
 import os
-import os.path
 import math
 import random
 import numpy as np
+from tqdm import trange

 import torch
 from .base import BaseDataset
@@ -26,27 +26,24 @@ def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
         root = os.path.join(root, self.BASE_DIR)
         annFile = os.path.join(root, 'trainval_merged.json')
         imgDir = os.path.join(root, 'JPEGImages')
+        mask_file = os.path.join(root, self.split+'.pth')
         # training mode
-        if split == 'train':
-            phase = 'train'
-        elif split == 'val':
-            phase = 'val'
-        elif split == 'test':
-            phase = 'val'
-            #phase = 'test'
-        print('annFile', annFile)
-        print('imgDir', imgDir)
-        self.detail = Detail(annFile, imgDir, phase)
+        self.detail = Detail(annFile, imgDir, split)
         self.transform = transform
         self.target_transform = target_transform
         self.ids = self.detail.getImgs()
+        # generate masks
         self._mapping = np.sort(np.array([
             0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
             23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
             427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424,
             68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
             98, 187, 104, 105, 366, 189, 368, 113, 115]))
         self._key = np.array(range(len(self._mapping))).astype('uint8')
+        if os.path.exists(mask_file):
+            self.masks = torch.load(mask_file)
+        else:
+            self.masks = self._preprocess(mask_file)

     def _class_to_index(self, mask):
         # assert the values
@@ -57,19 +54,33 @@ def _class_to_index(self, mask):
         index = np.digitize(mask.ravel(), self._mapping, right=True)
         return self._key[index].reshape(mask.shape)

+    def _preprocess(self, mask_file):
+        masks = {}
+        tbar = trange(len(self.ids))
+        print("Preprocessing mask, this will take a while." + \
+            "But don't worry, it only run once for each split.")
+        for i in tbar:
+            img_id = self.ids[i]
+            mask = Image.fromarray(self._class_to_index(
+                self.detail.getMask(img_id)))
+            masks[img_id['image_id']] = mask
+            tbar.set_description("Preprocessing masks {}".format(img_id['image_id']))
+        torch.save(masks, mask_file)
+        return masks
+
     def __getitem__(self, index):
-        detail = self.detail
         img_id = self.ids[index]
         path = img_id['file_name']
         iid = img_id['image_id']
-        img = Image.open(os.path.join(detail.img_folder, path)).convert('RGB')
+        img = Image.open(os.path.join(self.detail.img_folder, path)).convert('RGB')
         if self.mode == 'test':
             if self.transform is not None:
                 img = self.transform(img)
             return img, os.path.basename(path)
         # convert mask to 60 categories
-        mask = Image.fromarray(self._class_to_index(
-            detail.getMask(img_id)))
+        #mask = Image.fromarray(self._class_to_index(
+        #    self.detail.getMask(img_id)))
+        mask = self.masks[iid]
         # synchrosized transform
         if self.mode == 'train':
             img, mask = self._sync_transform(img, mask)
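
Note on `_class_to_index` in `pcontext.py`: the method uses `np.digitize(..., right=True)` against the sorted `_mapping` array to turn raw PASCAL-Context label ids into contiguous class indices. For a value that is present in a sorted list, the standard-library `bisect_left` returns the same position, so the lookup can be illustrated without NumPy. In the sketch below, `MAPPING` is a shortened, hypothetical stand-in for the 60-entry array, `class_to_index` is not the repository's function, and the membership check is only this sketch's reading of the "assert the values" comment, whose body is not shown in the diff.

```python
from bisect import bisect_left

# Shortened, hypothetical stand-in for the sorted 60-entry `_mapping` array.
MAPPING = sorted([0, 2, 259, 260, 415, 324, 9, 258])


def class_to_index(mask):
    """Map raw label ids in a nested-list mask to positions in MAPPING."""
    out = []
    for row in mask:
        new_row = []
        for value in row:
            i = bisect_left(MAPPING, value)
            # Assumed check: every id in the mask must appear in MAPPING.
            if i == len(MAPPING) or MAPPING[i] != value:
                raise ValueError('unexpected label id: {}'.format(value))
            new_row.append(i)
        out.append(new_row)
    return out


indexed = class_to_index([[0, 2, 9], [259, 415, 324]])
print(indexed)  # [[0, 1, 2], [4, 7, 6]]
```

Because this conversion is deterministic per image, caching the converted masks once per split, as the new `_preprocess` method does, avoids repeating it on every `__getitem__` call.
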
16 changes: 10 additions & 6 deletions encoding/dilated/resnet.py
@@ -1,7 +1,7 @@
 """Dilated ResNet"""
 import math
+import torch
 import torch.utils.model_zoo as model_zoo
-#from .. import nn
 import torch.nn as nn

 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -234,31 +234,35 @@ def resnet34(pretrained=False, **kwargs):
     return model


-def resnet50(pretrained=False, **kwargs):
+def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
     """Constructs a ResNet-50 model.
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
+        from ..models.model_store import get_model_file
+        model.load_state_dict(torch.load(
+            get_model_file('resnet50', root=root)), strict=False)
     return model


-def resnet101(pretrained=False, **kwargs):
+def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
     """Constructs a ResNet-101 model.
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
+        from ..models.model_store import get_model_file
+        model.load_state_dict(torch.load(
+            get_model_file('resnet101', root=root)), strict=False)
     return model


-def resnet152(pretrained=False, **kwargs):
+def resnet152(pretrained=False, root='~/.encoding/models', **kwargs):
     """Constructs a ResNet-152 model.
     Args: