diff --git a/README.md b/README.md
index dbe0aca4..dcd2134b 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ created by [Hang Zhang](http://hangzh.com/)
 
 - Please visit the [**Docs**](http://hangzh.com/PyTorch-Encoding/) for detail instructions of installation and usage.
 
-- How to use Synchronized Batch Normalization (SyncBN)? See the [examples](https://github.com/zhanghang1989/PyTorch-SyncBatchNorm).
+- Please visit this [link](http://hangzh.com/PyTorch-Encoding/experiments/segmentation.html) for examples of semantic segmentation.
 
 ## Citations
diff --git a/encoding/datasets/base.py b/encoding/datasets/base.py
index bd6b1290..6daf2e63 100644
--- a/encoding/datasets/base.py
+++ b/encoding/datasets/base.py
@@ -106,7 +106,4 @@ def test_batchify_fn(data):
     elif isinstance(data[0], (tuple, list)):
         data = zip(*data)
         return [test_batchify_fn(i) for i in data]
-    elif isinstance(data[0], np.ndarray):
-        data = np.asarray(data)
-        return mx.nd.array(data, dtype=data.dtype)
     raise TypeError((error_msg.format(type(batch[0]))))
diff --git a/encoding/models/base.py b/encoding/models/base.py
index 0137c3e2..60624e23 100644
--- a/encoding/models/base.py
+++ b/encoding/models/base.py
@@ -20,7 +20,7 @@
 
 up_kwargs = {'mode': 'bilinear', 'align_corners': True}
 
-__all__ = ['BaseNet', 'EvalModule', 'MultiEvalModule']
+__all__ = ['BaseNet', 'MultiEvalModule']
 
 class BaseNet(nn.Module):
     def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
@@ -65,16 +65,6 @@ def evaluate(self, x, target=None):
         return correct, labeled, inter, union
 
 
-class EvalModule(nn.Module):
-    """Segmentation Eval Module"""
-    def __init__(self, module):
-        super(EvalModule, self).__init__()
-        self.module = module
-
-    def forward(self, *inputs, **kwargs):
-        return self.module.evaluate(*inputs, **kwargs)
-
-
 class MultiEvalModule(DataParallel):
     """Multi-size Segmentation Eavluator"""
     def __init__(self, module, nclass, device_ids=None,
@@ -125,11 +115,11 @@ def forward(self, image):
             height = int(1.0 * h * long_size / w + 0.5)
             short_size = height
             # resize image to current size
-            cur_img = resize_image(image, height, width)
-            if scale <= 1.25 or long_size <= crop_size:# #
+            cur_img = resize_image(image, height, width, **self.module._up_kwargs)
+            if long_size <= crop_size:
                 pad_img = pad_image(cur_img, self.module.mean,
                                     self.module.std, crop_size)
-                outputs = self.module_inference(pad_img)
+                outputs = module_inference(self.module, pad_img, self.flip)
                 outputs = crop_image(outputs, 0, height, 0, width)
             else:
                 if short_size < crop_size:
@@ -157,7 +147,7 @@ def forward(self, image):
                         # pad if needed
                         pad_crop_img = pad_image(crop_img, self.module.mean,
                                                  self.module.std, crop_size)
-                        output = self.module_inference(pad_crop_img)
+                        output = module_inference(self.module, pad_crop_img, self.flip)
                         outputs[:,:,h0:h1,w0:w1] += crop_image(output,
                                                                0, h1-h0, 0, w1-w0)
                         count_norm[:,:,h0:h1,w0:w1] += 1
@@ -165,21 +155,21 @@ def forward(self, image):
             outputs = outputs / count_norm
             outputs = outputs[:,:,:height,:width]
 
-            score = resize_image(outputs, h, w)
+            score = resize_image(outputs, h, w, **self.module._up_kwargs)
             scores += score
 
         return scores
 
-    def module_inference(self, image):
-        output = self.module.evaluate(image)
-        if self.flip:
-            fimg = flip_image(image)
-            foutput = self.module.evaluate(fimg)
-            output += flip_image(foutput)
-        return output.exp()
+def module_inference(module, image, flip=True):
+    output = module.evaluate(image)
+    if flip:
+        fimg = flip_image(image)
+        foutput = module.evaluate(fimg)
+        output += flip_image(foutput)
+    return output.exp()
 
 
-def resize_image(img, h, w, mode='bilinear'):
+def resize_image(img, h, w, **up_kwargs):
     return F.upsample(img, (h, w), **up_kwargs)
 
 def pad_image(img, mean, std, crop_size):
@@ -189,11 +179,9 @@ def pad_image(img, mean, std, crop_size):
     padw = crop_size - w if w < crop_size else 0
     pad_values = -np.array(mean) / np.array(std)
     img_pad = img.new().resize_(b,c,h+padh,w+padw)
-    #img_pad = F.pad(img, (0,padw,0,padh))
     for i in range(c):
         # note that pytorch pad params is in reversed orders
-        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh),
-                                 value=pad_values[i])
+        img_pad[:,i,:,:] = F.pad(img[:,i,:,:], (0, padw, 0, padh), value=pad_values[i])
     assert(img_pad.size(2)>=crop_size and img_pad.size(3)>=crop_size)
     return img_pad
 
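Note: `module_inference` is now a module-level function that takes the wrapped model explicitly, so flip-averaged single-scale inference no longer requires constructing a `MultiEvalModule`. A minimal usage sketch, assuming a pretrained model from `encoding.models` and an illustrative random input:

    import torch
    import encoding
    from encoding.models.base import module_inference

    model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
    model.eval()
    img = torch.randn(1, 3, 480, 480).cuda()  # stand-in for a normalized image batch
    with torch.no_grad():
        # averages the model output over the image and its horizontal flip
        prob = module_inference(model, img, flip=True)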
diff --git a/encoding/models/fcn.py b/encoding/models/fcn.py
index 47e250ce..f3d310f4 100644
--- a/encoding/models/fcn.py
+++ b/encoding/models/fcn.py
@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
     >>> model = get_fcn_resnet50_pcontext(pretrained=True)
     >>> print(model)
     """
-    return get_fcn('pcontext', 'resnet50', pretrained)
+    return get_fcn('pcontext', 'resnet50', pretrained, aux=False)
 
 def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
     r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
diff --git a/encoding/nn/customize.py b/encoding/nn/customize.py
index 836c9a06..ceeb2457 100644
--- a/encoding/nn/customize.py
+++ b/encoding/nn/customize.py
@@ -21,7 +21,6 @@
 __all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean',
            'Normalize']
 
-
 class GramMatrix(Module):
     r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
diff --git a/experiments/recognition/main.py b/experiments/recognition/main.py
index 43c5574f..43662203 100644
--- a/experiments/recognition/main.py
+++ b/experiments/recognition/main.py
@@ -45,8 +45,8 @@ def main():
         torch.cuda.manual_seed(args.seed)
     # init dataloader
     dataset = importlib.import_module('dataset.'+args.dataset)
-    Dataloder = dataset.Dataloder
-    train_loader, test_loader = Dataloder(args).getloader()
+    Dataloader = dataset.Dataloader
+    train_loader, test_loader = Dataloader(args).getloader()
     # init the model
     models = importlib.import_module('model.'+args.model)
     model = models.Net(args)
diff --git a/experiments/segmentation/demo.py b/experiments/segmentation/demo.py
new file mode 100644
index 00000000..fde8848c
--- /dev/null
+++ b/experiments/segmentation/demo.py
@@ -0,0 +1,21 @@
+import torch
+import encoding
+
+# Get the model
+model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
+model.eval()
+
+# Prepare the image
+url = 'https://github.com/zhanghang1989/image-data/blob/master/' + \
+    'encoding/segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
+filename = 'example.jpg'
+img = encoding.utils.load_image(
+    encoding.utils.download(url, filename)).cuda().unsqueeze(0)
+
+# Make prediction
+output = model.evaluate(img)
+predict = torch.max(output, 1)[1].cpu().numpy() + 1
+
+# Get color pallete for visualization
+mask = encoding.utils.get_mask_pallete(predict, 'ade20k')
+mask.save('output.png')
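The new demo script downloads a sample ADE20K validation image and writes a colorized mask. The same pipeline works on a local file; a sketch, where 'my_photo.jpg' is a placeholder path:

    import torch
    import encoding

    model = encoding.models.get_model('fcn_resnet50_ade', pretrained=True).cuda()
    model.eval()
    img = encoding.utils.load_image('my_photo.jpg').cuda().unsqueeze(0)
    output = model.evaluate(img)
    # +1 shifts to the 1-based labels the ADE20K palette expects
    predict = torch.max(output, 1)[1].cpu().numpy() + 1
    encoding.utils.get_mask_pallete(predict, 'ade20k').save('my_photo_mask.png')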
diff --git a/experiments/segmentation/test.py b/experiments/segmentation/test.py
index 878a06e6..66deda8c 100644
--- a/experiments/segmentation/test.py
+++ b/experiments/segmentation/test.py
@@ -44,7 +44,7 @@ def test(args):
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
-    test_data = data.DataLoader(testset, batch_size=args.batch_size,
+    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                 drop_last=False, shuffle=False,
                                 collate_fn=test_batchify_fn, **kwargs)
     # model
@@ -105,8 +105,8 @@ def eval_batch(image, dst, evaluator, eval_mode):
         with torch.no_grad():
             correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
         if args.eval:
-            total_correct += correct
-            total_label += labeled
+            total_correct += correct.astype('int64')
+            total_label += labeled.astype('int64')
             total_inter += inter.astype('int64')
             total_union += union.astype('int64')
             pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
diff --git a/experiments/segmentation/test_models.py b/experiments/segmentation/test_models.py
new file mode 100644
index 00000000..92de0175
--- /dev/null
+++ b/experiments/segmentation/test_models.py
@@ -0,0 +1,19 @@
+import importlib
+import torch
+import encoding
+from option import Options
+from torch.autograd import Variable
+
+if __name__ == "__main__":
+    args = Options().parse()
+    model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset, aux=args.aux,
+                                                   se_loss=args.se_loss, norm_layer=torch.nn.BatchNorm2d)
+    print('Creating the model:')
+
+    print(model)
+    model.cuda()
+    x = Variable(torch.Tensor(4, 3, 480, 480)).cuda()
+    with torch.no_grad():
+        out = model(x)
+    for y in out:
+        print(y.size())
diff --git a/experiments/segmentation/train.py b/experiments/segmentation/train.py
index d6328daf..739a7212 100644
--- a/experiments/segmentation/train.py
+++ b/experiments/segmentation/train.py
@@ -60,18 +60,6 @@ def __init__(self, args):
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
-        # resuming checkpoint
-        if args.resume is not None:
-            if not os.path.isfile(args.resume):
-                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
-            checkpoint = torch.load(args.resume)
-            args.start_epoch = checkpoint['epoch']
-            model.load_state_dict(checkpoint['state_dict'])
-            if not args.ft:
-                optimizer.load_state_dict(checkpoint['optimizer'])
-            best_pred = checkpoint['best_pred']
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
         # clear start epoch if fine-tuning
         if args.ft:
             args.start_epoch = 0
@@ -82,6 +70,21 @@ def __init__(self, args):
         if args.cuda:
             self.model = DataParallelModel(self.model).cuda()
             self.criterion = DataParallelCriterion(self.criterion).cuda()
+        # resuming checkpoint
+        if args.resume is not None:
+            if not os.path.isfile(args.resume):
+                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            if args.cuda:
+                self.model.module.load_state_dict(checkpoint['state_dict'])
+            else:
+                self.model.load_state_dict(checkpoint['state_dict'])
+            if not args.ft:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            self.best_pred = checkpoint['best_pred']
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
         # lr scheduler
         self.scheduler = utils.LR_Scheduler(args, len(self.trainloader))
         self.best_pred = 0.0
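The resume block now runs after the model has been wrapped in `DataParallelModel`, so a checkpoint saved without the 'module.' prefix loads into `self.model.module`, and `self.optimizer` exists before its state is restored (note that `self.best_pred = 0.0` after the scheduler still overwrites the restored value). A matching save-side sketch, only to illustrate the checkpoint keys the resume code expects:

    import torch

    def save_checkpoint(model, optimizer, epoch, best_pred, path='checkpoint.pth.tar'):
        # unwrap DataParallel so the saved keys carry no 'module.' prefix
        state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        torch.save({'epoch': epoch,
                    'state_dict': state_dict,
                    'optimizer': optimizer.state_dict(),
                    'best_pred': best_pred}, path)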
" + repo_url) + os.system("cd detail-api/PythonAPI/ && python setup.py install") + shutil.rmtree('detail-api') + try: + import detail + except Exception: + print("Installing PASCAL Context API failed, please install it manually %s"%(repo_url)) + + if __name__ == '__main__': args = parse_args() mkdir(os.path.expanduser('~/.encoding/data')) @@ -42,3 +53,4 @@ def download_ade(path, overwrite=False): os.symlink(args.download_dir, _TARGET_DIR) else: download_ade(_TARGET_DIR, overwrite=False) + install_pcontext_api()