Skip to content

Commit

Permalink
DeepTen model (zhanghang1989#85)
Browse files Browse the repository at this point in the history
* deepten model

fixes zhanghang1989#77
fixes zhanghang1989#32
  • Loading branch information
zhanghang1989 committed Jun 27, 2018
1 parent b05334f commit 0403d28
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docs/source/experiments/segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Test Pre-trained Model
</code>

<code xml:space="preserve" id="cmd_enc50_ade" style="display: none; text-align: left; white-space: pre-wrap">
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ade20k --model encnetv2 --aux --se-loss
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
</code>

Quick Demo
Expand Down
17 changes: 6 additions & 11 deletions docs/source/experiments/texture.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,28 @@ Test Pre-trained Model

- Clone the GitHub repo::

git clone git@github.com:zhanghang1989/PyTorch-Encoding.git
git clone https://github.com/zhanghang1989/PyTorch-Encoding

- Install PyTorch Encoding (if not yet). Please follow the installation guide `Installing PyTorch Encoding <../notes/compile.html>`_.

- Download the `MINC-2500 <http://opensurfaces.cs.cornell.edu/publications/minc/>`_ dataset to ``$HOME/data/minc-2500/`` folder. Download pre-trained model (training `curve`_ as bellow, pre-trained on train-1 split using single training size of 224, with an error rate of :math:`19.98\%` using single crop on test-1 set)::
- Download the `MINC-2500 <http://opensurfaces.cs.cornell.edu/publications/minc/>`_ dataset to ``$HOME/data/minc-2500/`` folder. Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`19.70\%` using single crop on test-1 set)::

cd PyTorch-Encoding/experiments/recognition
bash model/download_models.sh

.. _curve:

.. image:: ../_static/img/deep_ten_curve.svg
:width: 70%
python model/download_models.py

- Test pre-trained model on MINC-2500::

>>> python main.py --dataset minc --model deepten --nclass 23 --resume model/minc.pth.tar --eval
python main.py --dataset minc --model deepten --nclass 23 --resume deepten_minc.pth --eval
# Teriminal Output:
#[======================================== 23/23 ===================================>...] Step: 104ms | Tot: 3s256ms | Loss: 0.719 | Err: 19.983% (1149/5750)
# Loss: 1.005 | Err: 19.704% (1133/5750): 100%|████████████████████| 23/23 [00:18<00:00, 1.26it/s]


Train Your Own Model
--------------------

- Example training command for training above model::

python main.py --model deepten --nclass 23 --model deepten --batch-size 64 --lr 0.01 --epochs 60
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset minc --model deepten --nclass 23 --model deepten --batch-size 512 --lr 0.004 --epochs 80 --lr-step 60

- Detail training options::

Expand Down
1 change: 1 addition & 0 deletions encoding/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .model_zoo import get_model
from .model_store import get_model_file
from .base import *
from .fcn import *
from .psp import *
Expand Down
1 change: 1 addition & 0 deletions encoding/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_model_sha1 = {name: checksum for checksum, name in [
('853f2fb07aeb2927f7696e166b215609a987fd44', 'resnet50'),
('5be5422ad7cb6a2e5f5a54070d0aa9affe69a9a4', 'resnet101'),
('6cb047cda851de6aa31963e779fae5f4c299056a', 'deepten_minc'),
('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
Expand Down
6 changes: 3 additions & 3 deletions experiments/recognition/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def train(epoch):
loss.backward()
optimizer.step()

train_loss += loss.data[0]
train_loss += loss.data.item()
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
total += target.size(0)
err = 100-100.*correct/total
err = 100.0 - 100.0 * correct / total
tbar.set_description('\rLoss: %.3f | Err: %.3f%% (%d/%d)' % \
(train_loss/(batch_idx+1), err, total-correct, total))

Expand All @@ -122,7 +122,7 @@ def test(epoch):
correct += pred.eq(target.data).cpu().sum().item()
total += target.size(0)

err = 100-100.0*correct/total
err = 100.0 - 100.0 * correct / total
tbar.set_description('Loss: %.3f | Err: %.3f%% (%d/%d)'% \
(test_loss/(batch_idx+1), err, total-correct, total))

Expand Down
8 changes: 4 additions & 4 deletions experiments/recognition/model/deepten.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.autograd import Variable

import encoding
import torchvision.models as resnet
import encoding.dilated.resnet as resnet

class Net(nn.Module):
def __init__(self, args):
Expand All @@ -23,11 +23,11 @@ def __init__(self, args):
self.backbone = args.backbone
# copying modules from pretrained models
if self.backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True)
self.pretrained = resnet.resnet50(pretrained=True, dilated=False)
elif self.backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True)
self.pretrained = resnet.resnet101(pretrained=True, dilated=False)
elif self.backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True)
self.pretrained = resnet.resnet152(pretrained=True, dilated=False)
else:
raise RuntimeError('unknown backbone: {}'.format(self.backbone))
n_codes = 32
Expand Down
5 changes: 5 additions & 0 deletions experiments/recognition/model/download_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import encoding
import shutil

encoding.models.get_model_file('deepten_minc', root='./')
shutil.move('deepten_minc-6cb047cd.pth', 'deepten_minc.pth')
2 changes: 1 addition & 1 deletion experiments/recognition/model/mynn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, inplanes, planes, stride=1,norm_layer=nn.BatchNorm2d):
conv_block = []
conv_block += [norm_layer(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, planes, kernel_size=1,
nn.Conv2d(inplanes, planes, kernel_size=1,
stride=1, bias=False)]
conv_block += [norm_layer(planes),
nn.ReLU(inplace=True),
Expand Down

0 comments on commit 0403d28

Please sign in to comment.