
add detail API and other fixes (zhanghang1989#63)
zhanghang1989 committed Jun 6, 2018
1 parent 3ba8d2f commit 9bc7053
Showing 11 changed files with 89 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -6,7 +6,7 @@ created by [Hang Zhang](http://hangzh.com/)

- Please visit the [**Docs**](http://hangzh.com/PyTorch-Encoding/) for detailed instructions on installation and usage.

- How to use Synchronized Batch Normalization (SyncBN)? See the [examples](https://github.com/zhanghang1989/PyTorch-SyncBatchNorm).
- Please visit the [link](http://hangzh.com/PyTorch-Encoding/experiments/segmentation.html) for examples of semantic segmentation.

## Citations

3 changes: 0 additions & 3 deletions encoding/datasets/base.py
@@ -106,7 +106,4 @@ def test_batchify_fn(data):
elif isinstance(data[0], (tuple, list)):
data = zip(*data)
return [test_batchify_fn(i) for i in data]
elif isinstance(data[0], np.ndarray):
data = np.asarray(data)
return mx.nd.array(data, dtype=data.dtype)
raise TypeError((error_msg.format(type(batch[0]))))
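
The three lines removed above were an MXNet leftover (`np.asarray`/`mx.nd.array` have no place in a PyTorch collate function); what remains is the branches above this hunk plus the nested tuple/list recursion shown here. A minimal sketch of that recursion (illustrative values, not from the commit):

# Sketch: the zip(*data) regrouping performed by the tuple/list branch above.
data = [('img_0.jpg', (480, 640)), ('img_1.jpg', (512, 512))]
regrouped = list(zip(*data))
# regrouped == [('img_0.jpg', 'img_1.jpg'), ((480, 640), (512, 512))]
# test_batchify_fn then recurses into each regrouped field:
# [test_batchify_fn(('img_0.jpg', 'img_1.jpg')), test_batchify_fn(((480, 640), (512, 512)))]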
42 changes: 15 additions & 27 deletions encoding/models/base.py
@@ -20,7 +20,7 @@

up_kwargs = {'mode': 'bilinear', 'align_corners': True}

__all__ = ['BaseNet', 'EvalModule', 'MultiEvalModule']
__all__ = ['BaseNet', 'MultiEvalModule']

class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
@@ -65,16 +65,6 @@ def evaluate(self, x, target=None):
return correct, labeled, inter, union


class EvalModule(nn.Module):
"""Segmentation Eval Module"""
def __init__(self, module):
super(EvalModule, self).__init__()
self.module = module

def forward(self, *inputs, **kwargs):
return self.module.evaluate(*inputs, **kwargs)


class MultiEvalModule(DataParallel):
"""Multi-size Segmentation Eavluator"""
def __init__(self, module, nclass, device_ids=None,
@@ -125,11 +115,11 @@ def forward(self, image):
height = int(1.0 * h * long_size / w + 0.5)
short_size = height
# resize image to current size
cur_img = resize_image(image, height, width)
if scale <= 1.25 or long_size <= crop_size:# #
cur_img = resize_image(image, height, width, **self.module._up_kwargs)
if long_size <= crop_size:
pad_img = pad_image(cur_img, self.module.mean,
self.module.std, crop_size)
outputs = self.module_inference(pad_img)
outputs = module_inference(self.module, pad_img, self.flip)
outputs = crop_image(outputs, 0, height, 0, width)
else:
if short_size < crop_size:
@@ -157,29 +147,29 @@ def forward(self, image):
# pad if needed
pad_crop_img = pad_image(crop_img, self.module.mean,
self.module.std, crop_size)
output = self.module_inference(pad_crop_img)
output = module_inference(self.module, pad_crop_img, self.flip)
outputs[:,:,h0:h1,w0:w1] += crop_image(output,
0, h1-h0, 0, w1-w0)
count_norm[:,:,h0:h1,w0:w1] += 1
assert((count_norm==0).sum()==0)
outputs = outputs / count_norm
outputs = outputs[:,:,:height,:width]

score = resize_image(outputs, h, w)
score = resize_image(outputs, h, w, **self.module._up_kwargs)
scores += score

return scores

def module_inference(self, image):
output = self.module.evaluate(image)
if self.flip:
fimg = flip_image(image)
foutput = self.module.evaluate(fimg)
output += flip_image(foutput)
return output.exp()

def module_inference(module, image, flip=True):
output = module.evaluate(image)
if flip:
fimg = flip_image(image)
foutput = module.evaluate(fimg)
output += flip_image(foutput)
return output.exp()
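
With `module_inference` now a free function, the patched `forward` above calls it as `module_inference(self.module, pad_img, self.flip)`. A hedged usage sketch outside the evaluator (the input tensor and `net` are illustrative; any model exposing `evaluate(x)` works):

# Sketch: flip-augmented inference with the free function above.
padded = torch.rand(1, 3, 480, 480).cuda()        # illustrative padded crop
probs = module_inference(net, padded, flip=True)
# internally: sums net.evaluate(x) with the un-flipped net.evaluate(flip(x)),
# then exponentiates, so evaluate() is assumed to return log-probability-like scores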

def resize_image(img, h, w, mode='bilinear'):
def resize_image(img, h, w, **up_kwargs):
return F.upsample(img, (h, w), **up_kwargs)

def pad_image(img, mean, std, crop_size):
@@ -189,11 +179,9 @@ def pad_image(img, mean, std, crop_size):
padw = crop_size - w if w < crop_size else 0
pad_values = -np.array(mean) / np.array(std)
img_pad = img.new().resize_(b,c,h+padh,w+padw)
#img_pad = F.pad(img, (0,padw,0,padh))
for i in range(c):
# note that pytorch pad params is in reversed orders
img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh),
value=pad_values[i])
img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
return img_pad

2 changes: 1 addition & 1 deletion encoding/models/fcn.py
@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_fcn('pcontext', 'resnet50', pretrained)
return get_fcn('pcontext', 'resnet50', pretrained, aux=False)

def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
1 change: 0 additions & 1 deletion encoding/nn/customize.py
@@ -21,7 +21,6 @@
__all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean',
'Normalize']


class GramMatrix(Module):
r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
4 changes: 2 additions & 2 deletions experiments/recognition/main.py
@@ -45,8 +45,8 @@ def main():
torch.cuda.manual_seed(args.seed)
# init dataloader
dataset = importlib.import_module('dataset.'+args.dataset)
Dataloder = dataset.Dataloder
train_loader, test_loader = Dataloder(args).getloader()
Dataloader = dataset.Dataloader
train_loader, test_loader = Dataloader(args).getloader()
# init the model
models = importlib.import_module('model.'+args.model)
model = models.Net(args)
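
The `Dataloder` → `Dataloader` rename only changes the attribute looked up on the imported dataset module, so each `dataset/<name>.py` now needs to expose a `Dataloader` class with a `getloader()` method returning `(train_loader, test_loader)`. A hypothetical minimal module satisfying that contract (everything below is illustrative, not code from the repository):

# dataset/mydataset.py -- hypothetical sketch of the interface main.py expects after the rename.
import torch
import torch.utils.data as data

class Dataloader(object):
    def __init__(self, args):
        self.args = args
        # placeholder datasets; a real module would build them from args
        self.trainset = data.TensorDataset(torch.rand(8, 3, 224, 224), torch.zeros(8, dtype=torch.long))
        self.testset = data.TensorDataset(torch.rand(4, 3, 224, 224), torch.zeros(4, dtype=torch.long))

    def getloader(self):
        train_loader = data.DataLoader(self.trainset, batch_size=2, shuffle=True)
        test_loader = data.DataLoader(self.testset, batch_size=2, shuffle=False)
        return train_loader, test_loader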
21 changes: 21 additions & 0 deletions experiments/segmentation/demo.py
@@ -0,0 +1,21 @@
import torch
import encoding

# Get the model
model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
model.eval()

# Prepare the image
url = 'https://github.com/zhanghang1989/image-data/blob/master/' + \
'encoding/segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
filename = 'example.jpg'
img = encoding.utils.load_image(
encoding.utils.download(url, filename)).cuda().unsqueeze(0)

# Make prediction
output = model.evaluate(img)
predict = torch.max(output, 1)[1].cpu().numpy() + 1

# Get the color palette for visualization
mask = encoding.utils.get_mask_pallete(predict, 'ade20k')
mask.save('output.png')
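
Not part of the commit, but a quick sanity check of the demo: `output` holds per-class scores along dimension 1 (hence the `torch.max(output, 1)`), and the `+ 1` shifts the 0-based argmax indices to the 1-based IDs the ADE20K palette appears to expect. For instance, listing which class IDs the prediction contains:

# Optional sanity check (not in the commit): class IDs present in the prediction and their pixel counts.
import numpy as np
ids, counts = np.unique(predict, return_counts=True)
print(list(zip(ids.tolist(), counts.tolist())))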
6 changes: 3 additions & 3 deletions experiments/segmentation/test.py
@@ -44,7 +44,7 @@ def test(args):
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
if args.cuda else {}
test_data = data.DataLoader(testset, batch_size=args.batch_size,
test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
drop_last=False, shuffle=False,
collate_fn=test_batchify_fn, **kwargs)
# model
@@ -105,8 +105,8 @@ def eval_batch(image, dst, evaluator, eval_mode):
with torch.no_grad():
correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
if args.eval:
total_correct += correct
total_label += labeled
total_correct += correct.astype('int64')
total_label += labeled.astype('int64')
total_inter += inter.astype('int64')
total_union += union.astype('int64')
pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
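
The `.astype('int64')` casts keep the running pixel counts exact: accumulated over many crops and scales they can pass the int32 limit of 2,147,483,647, after which the accuracy ratios would silently go wrong. A small illustration (numbers are made up, not from the commit):

# Illustration of the overflow the int64 casts guard against.
import numpy as np
per_crop = 480 * 480                  # labeled pixels in one 480x480 crop
total_label = np.int64(0)             # an int32 accumulator would wrap past ~2.1e9
for _ in range(10000):                # ~2.3e9 pixels over a long multi-scale evaluation
    total_label += np.int64(per_crop)
print(total_label)                    # 2304000000, still exact in int64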
19 changes: 19 additions & 0 deletions experiments/segmentation/test_models.py
@@ -0,0 +1,19 @@
import importlib
import torch
import encoding
from option import Options
from torch.autograd import Variable

if __name__ == "__main__":
args = Options().parse()
model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset, aux=args.aux,
se_loss=args.se_loss, norm_layer=torch.nn.BatchNorm2d)
print('Creating the model:')

print(model)
model.cuda()
x = Variable(torch.Tensor(4, 3, 480, 480)).cuda()
with torch.no_grad():
out = model(x)
for y in out:
print(y.size())
27 changes: 15 additions & 12 deletions experiments/segmentation/train.py
@@ -60,18 +60,6 @@ def __init__(self, args):
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
# resuming checkpoint
if args.resume is not None:
if not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
if not args.ft:
optimizer.load_state_dict(checkpoint['optimizer'])
best_pred = checkpoint['best_pred']
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
# clear start epoch if fine-tuning
if args.ft:
args.start_epoch = 0
@@ -82,6 +70,21 @@
if args.cuda:
self.model = DataParallelModel(self.model).cuda()
self.criterion = DataParallelCriterion(self.criterion).cuda()
# resuming checkpoint
if args.resume is not None:
if not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
if args.cuda:
self.model.module.load_state_dict(checkpoint['state_dict'])
else:
self.model.load_state_dict(checkpoint['state_dict'])
if not args.ft:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_pred = checkpoint['best_pred']
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
# lr scheduler
self.scheduler = utils.LR_Scheduler(args, len(self.trainloader))
self.best_pred = 0.0
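
Moving the resume logic after the `DataParallelModel`/`DataParallelCriterion` wrap means the checkpoint is loaded into `self.model.module`, i.e. the unwrapped network. A hedged sketch (not from this hunk) of the checkpoint layout the block above expects, saved so the keys carry no `module.` prefix:

# Sketch: saving a checkpoint compatible with the resume block above (illustrative stand-ins).
import torch
import torch.nn as nn

net = nn.DataParallel(nn.Conv2d(3, 8, 3))          # stand-in for DataParallelModel(self.model)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
state = {
    'epoch': 1,
    'state_dict': net.module.state_dict(),          # unwrapped keys match model.module.load_state_dict(...)
    'optimizer': optimizer.state_dict(),
    'best_pred': 0.0,
}
torch.save(state, 'checkpoint.pth.tar')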
12 changes: 12 additions & 0 deletions scripts/prepare_pcontext.py
@@ -32,6 +32,17 @@ def download_ade(path, overwrite=False):
else:
shutil.move(filename, os.path.join(path, 'VOCdevkit/VOC2010/'+os.path.basename(filename)))

def install_pcontext_api():
repo_url = "https://github.com/zhanghang1989/detail-api"
os.system("git clone " + repo_url)
os.system("cd detail-api/PythonAPI/ && python setup.py install")
shutil.rmtree('detail-api')
try:
import detail
except Exception:
print("Installing PASCAL Context API failed, please install it manually %s"%(repo_url))


if __name__ == '__main__':
args = parse_args()
mkdir(os.path.expanduser('~/.encoding/data'))
@@ -42,3 +53,4 @@ def download_ade(path, overwrite=False):
os.symlink(args.download_dir, _TARGET_DIR)
else:
download_ade(_TARGET_DIR, overwrite=False)
install_pcontext_api()
