Skip to content

Commit

Permalink
fix path (zhanghang1989#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 committed Jun 15, 2018
1 parent 2dd88e5 commit 1636365
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
11 changes: 7 additions & 4 deletions encoding/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
mean=[.485, .456, .406], std=[.229, .224, .225]):
mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'):
super(BaseNet, self).__init__()
self.nclass = nclass
self.aux = aux
Expand All @@ -33,11 +33,14 @@ def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None
self.std = std
# copying modules from pretrained models
if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
elif backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer)
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
# bilinear upsample options
Expand Down
5 changes: 3 additions & 2 deletions encoding/models/encnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
norm_layer=nn.BatchNorm2d, **kwargs):
super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
norm_layer=norm_layer, **kwargs)
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
lateral=lateral, norm_layer=norm_layer,
up_kwargs=self._up_kwargs)
Expand Down Expand Up @@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
Expand Down
8 changes: 4 additions & 4 deletions encoding/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FCN(BaseNet):
>>> print(model)
"""
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = FCNHead(2048, nclass, norm_layer)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
Expand Down Expand Up @@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
Expand All @@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)

def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
Expand All @@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_fcn_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_fcn('ade20k', 'resnet50', pretrained, **kwargs)
return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs)
6 changes: 3 additions & 3 deletions encoding/models/psp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class PSP(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
Expand Down Expand Up @@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
Expand All @@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_psp_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_psp('ade20k', 'resnet50', pretrained)
return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)

0 comments on commit 1636365

Please sign in to comment.