diff --git a/compare.py b/compare.py new file mode 100644 index 0000000..bcaf3e6 --- /dev/null +++ b/compare.py @@ -0,0 +1,145 @@ +import matplotlib.pyplot as plt +import os +import sys +import numpy as np +from collections import OrderedDict + +def savefig(fname, dpi=None): + dpi = 150 if dpi == None else dpi + plt.savefig(fname, dpi=dpi) + + +class Logger(object): + '''Save training process to log file with simple plot function.''' + def __init__(self, fpath, title=None, resume=False): + self.file = None + self.resume = resume + self.title = '' if title == None else title + if fpath is not None: + if resume: + self.file = open(fpath, 'r') + name = self.file.readline() + self.names = name.rstrip().split('\t') + self.numbers = {} + for _, name in enumerate(self.names): + self.numbers[name] = [] + + for numbers in self.file: + numbers = numbers.rstrip().split('\t') + for i in range(0, len(numbers)): + self.numbers[self.names[i]].append(numbers[i]) + self.file.close() + self.file = open(fpath, 'a') + else: + self.file = open(fpath, 'w') + + def set_names(self, names): + if self.resume: + pass + # initialize numbers as empty list + self.numbers = {} + self.names = names + for _, name in enumerate(self.names): + self.file.write(name) + self.file.write('\t') + self.numbers[name] = [] + self.file.write('\n') + self.file.flush() + + + def append(self, numbers): + assert len(self.names) == len(numbers), 'Numbers do not match names' + for index, num in enumerate(numbers): + self.file.write("{0:.6f}".format(num)) + self.file.write('\t') + self.numbers[self.names[index]].append(num) + self.file.write('\n') + self.file.flush() + + def plot(self, names=None): + names = self.names if names == None else names + numbers = self.numbers + for _, name in enumerate(names): + x = np.arange(len(numbers[name])) + plt.plot(x, np.asarray(numbers[name])) + plt.legend([self.title + '(' + name + ')' for name in names]) + plt.grid(True) + + def close(self): + if self.file is not None: + self.file.close() + +def plot_overlap(logger, names, axs, indexes): + names = logger.names if names == None else names + numbers = logger.numbers + for _, name in enumerate(names): + x = np.arange(len(numbers[name])) + axs[indexes[0],indexes[1]].plot(x, np.asarray(numbers[name])) + + return [logger.title for name in names] + +class LoggerMonitor(object): + '''Load and visualize multiple logs.''' + def __init__ (self, paths): + '''paths is a distionary with {name:filepath} pair''' + self.loggers = [] + for title, path in paths.items(): + logger = Logger(path, title=title, resume=True) + self.loggers.append(logger) + + def plot(self, names, axs, indexes): + + legend_text = [] + for logger in self.loggers: + legend_text += plot_overlap(logger, names, axs, indexes) + + axs[indexes[0],indexes[1]].set_title(names[0]) + return legend_text + +paths = OrderedDict() +fields = [['Train Loss'],['Valid Loss'], ['Train Acc.'],['Valid Acc.']] + +paths['SGD-Momentum']='momentum.txt' +I=float(1) +learning_rate = 0.01 + +name = 'pid.txt' +paths['PID']=name +fig, axs = plt.subplots(2, 2) + +i = 0 +indexes=[0,0] + +field = fields[i] +monitor = LoggerMonitor(paths) +legend_text = monitor.plot(names=field, axs=axs, indexes=indexes) +axs[indexes[0],indexes[1]].legend(legend_text, bbox_to_anchor=(1., 1.0), loc=1, borderaxespad=0.) 
+axs[indexes[0],indexes[1]].set_xlabel('Epoch') + +i = 1 +indexes=[0,1] +field = fields[i] +monitor = LoggerMonitor(paths) +legend_text = monitor.plot(names=field, axs=axs, indexes=indexes) +axs[indexes[0],indexes[1]].legend(legend_text, bbox_to_anchor=(1., 1.0), loc=1, borderaxespad=0.) +axs[indexes[0],indexes[1]].set_xlabel('Epoch') + +i = 2 +indexes=[1,0] +field = fields[i] +monitor = LoggerMonitor(paths) +legend_text = monitor.plot(names=field, axs=axs, indexes=indexes) +axs[indexes[0],indexes[1]].legend(legend_text, bbox_to_anchor=(1., 0.), loc=4, borderaxespad=0.) +axs[indexes[0],indexes[1]].set_xlabel('Epoch') + +i = 3 +indexes=[1,1] +field = fields[i] +monitor = LoggerMonitor(paths) +legend_text = monitor.plot(names=field, axs=axs, indexes=indexes) +axs[indexes[0],indexes[1]].legend(legend_text, bbox_to_anchor=(1., 0.), loc=4, borderaxespad=0.) +axs[indexes[0],indexes[1]].set_xlabel('Epoch') + +fig.tight_layout() + +savefig('moment_vs_pid.pdf') diff --git a/data/processed/test.pt b/data/processed/test.pt new file mode 100644 index 0000000..666c13a Binary files /dev/null and b/data/processed/test.pt differ diff --git a/data/processed/training.pt b/data/processed/training.pt new file mode 100644 index 0000000..bd6f623 Binary files /dev/null and b/data/processed/training.pt differ diff --git a/data/raw/t10k-images-idx3-ubyte b/data/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/data/raw/t10k-images-idx3-ubyte differ diff --git a/data/raw/t10k-labels-idx1-ubyte b/data/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/data/raw/t10k-labels-idx1-ubyte differ diff --git a/data/raw/train-images-idx3-ubyte b/data/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/data/raw/train-images-idx3-ubyte differ diff --git a/data/raw/train-labels-idx1-ubyte b/data/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/data/raw/train-labels-idx1-ubyte differ diff --git a/mnist_moment.py b/mnist_moment.py new file mode 100644 index 0000000..b7028bc --- /dev/null +++ b/mnist_moment.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torchvision.datasets as dsets +import torchvision.transforms as transforms +from torch.autograd import Variable +from torch.optim.sgd import SGD +import os +from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig +import torch.nn.functional as F +# Hyper Parameters +input_size = 784 +hidden_size = 1000 +num_classes = 10 +num_epochs = 20 +batch_size = 100 +learning_rate = 0.01 + +logger = Logger('momentum.txt', title='mnist') +logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) + +# MNIST Dataset +train_dataset = dsets.MNIST(root='./data', + train=True, + transform=transforms.ToTensor(), + download=True) + +test_dataset = dsets.MNIST(root='./data', + train=False, + transform=transforms.ToTensor()) + +# Data Loader (Input Pipeline) +train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, + shuffle=True) + +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, + batch_size=batch_size, + shuffle=False) + +# Neural Network Model (1 hidden layer) +class Net(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(Net, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, num_classes) + + def forward(self, 
x): + out = self.fc1(x) + out = F.relu(out) + out = self.fc2(out) + return out + +net = Net(input_size, hidden_size, num_classes) +net.cuda() +net.train() +# Loss and Optimizer +criterion = nn.CrossEntropyLoss() +#optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate) +optimizer = SGD(net.parameters(), lr=learning_rate, weight_decay=0.0001, momentum=0.9) +# Train the Model +for epoch in range(num_epochs): + + train_loss_log = AverageMeter() + train_acc_log = AverageMeter() + val_loss_log = AverageMeter() + val_acc_log = AverageMeter() + for i, (images, labels) in enumerate(train_loader): + # Convert torch tensor to Variable + images = Variable(images.view(-1, 28*28).cuda()) + labels = Variable(labels.cuda()) + + # Forward + Backward + Optimize + optimizer.zero_grad() # zero the gradient buffer + outputs = net(images) + train_loss = criterion(outputs, labels) + train_loss.backward() + optimizer.step() + prec1, prec5 = accuracy(outputs.data, labels.data, topk=(1, 5)) + train_loss_log.update(train_loss.data[0], images.size(0)) + train_acc_log.update(prec1[0], images.size(0)) + + if (i+1) % 100 == 0: + print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Acc: %.8f' + %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, train_loss_log.avg, train_acc_log.avg)) + + # Test the Model + net.eval() + correct = 0 + loss = 0 + total = 0 + for images, labels in test_loader: + images = Variable(images.view(-1, 28*28)).cuda() + labels = Variable(labels).cuda() + outputs = net(images) + test_loss = criterion(outputs, labels) + val_loss_log.update(test_loss.data[0], images.size(0)) + prec1, prec5 = accuracy(outputs.data, labels.data, topk=(1, 5)) + val_acc_log.update(prec1[0], images.size(0)) + + logger.append([learning_rate, train_loss_log.avg, val_loss_log.avg, train_acc_log.avg, val_acc_log.avg]) + print('Accuracy of the network on the 10000 test images: %.8f %%' % (val_acc_log.avg)) + print('Loss of the network on the 10000 test images: %.8f' % (val_loss_log.avg)) + +logger.close() +logger.plot() + diff --git a/mnist_pid.py b/mnist_pid.py new file mode 100644 index 0000000..427c2a8 --- /dev/null +++ b/mnist_pid.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torchvision.datasets as dsets +import torchvision.transforms as transforms +from torch.autograd import Variable +from pid import PIDOptimizer +import os +import numpy as np +from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig +import torch.nn.functional as F +# Hyper Parameters +input_size = 784 +hidden_size = 1000 +num_classes = 10 +num_epochs = 20 +batch_size = 100 +learning_rate = 0.01 + +I=1 +I = float(I) +D = 200 +D = float(D) + + +logger = Logger('pid.txt', title='mnist') +logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) + +# MNIST Dataset +train_dataset = dsets.MNIST(root='./data', + train=True, + transform=transforms.ToTensor(), + download=True) + +test_dataset = dsets.MNIST(root='./data', + train=False, + transform=transforms.ToTensor()) + +# Data Loader (Input Pipeline) +train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, + shuffle=True) + +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, + batch_size=batch_size, + shuffle=False) + +# Neural Network Model (1 hidden layer) +class Net(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(Net, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, num_classes) + + def 
forward(self, x): + out = self.fc1(x) + out = F.relu(out) + out = self.fc2(out) + return out + +net = Net(input_size, hidden_size, num_classes) +net.cuda() +net.train() +# Loss and Optimizer +criterion = nn.CrossEntropyLoss() +optimizer = PIDOptimizer(net.parameters(), lr=learning_rate, weight_decay=0.0001, momentum=0.9, I=I, D=D) +# Train the Model +for epoch in range(num_epochs): + + train_loss_log = AverageMeter() + train_acc_log = AverageMeter() + val_loss_log = AverageMeter() + val_acc_log = AverageMeter() + for i, (images, labels) in enumerate(train_loader): + # Convert torch tensor to Variable + images = Variable(images.view(-1, 28*28).cuda()) + labels = Variable(labels.cuda()) + + # Forward + Backward + Optimize + optimizer.zero_grad() # zero the gradient buffer + outputs = net(images) + train_loss = criterion(outputs, labels) + train_loss.backward() + optimizer.step() + prec1, prec5 = accuracy(outputs.data, labels.data, topk=(1, 5)) + train_loss_log.update(train_loss.data[0], images.size(0)) + train_acc_log.update(prec1[0], images.size(0)) + + if (i+1) % 100 == 0: + print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Acc: %.8f' + %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, train_loss_log.avg, train_acc_log.avg)) + + # Test the Model + net.eval() + correct = 0 + loss = 0 + total = 0 + for images, labels in test_loader: + images = Variable(images.view(-1, 28*28)).cuda() + labels = Variable(labels).cuda() + outputs = net(images) + test_loss = criterion(outputs, labels) + val_loss_log.update(test_loss.data[0], images.size(0)) + prec1, prec5 = accuracy(outputs.data, labels.data, topk=(1, 5)) + val_acc_log.update(prec1[0], images.size(0)) + + logger.append([learning_rate, train_loss_log.avg, val_loss_log.avg, train_acc_log.avg, val_acc_log.avg]) + print('Accuracy of the network on the 10000 test images: %d %%' % (val_acc_log.avg)) + print('Loss of the network on the 10000 test images: %.8f' % (val_loss_log.avg)) + +logger.close() +logger.plot() + diff --git a/moment_vs_pid.pdf b/moment_vs_pid.pdf new file mode 100644 index 0000000..0b8a05a Binary files /dev/null and b/moment_vs_pid.pdf differ diff --git a/momentum.txt b/momentum.txt new file mode 100644 index 0000000..d6a0450 --- /dev/null +++ b/momentum.txt @@ -0,0 +1,21 @@ +Learning Rate Train Loss Valid Loss Train Acc. Valid Acc. +0.010000 0.516547 0.289208 86.581667 91.600000 +0.010000 0.259415 0.216152 92.545000 93.950000 +0.010000 0.201986 0.177764 94.343333 94.780000 +0.010000 0.165051 0.149667 95.318333 95.600000 +0.010000 0.139030 0.128418 96.070000 96.220000 +0.010000 0.119939 0.116352 96.710000 96.600000 +0.010000 0.104811 0.104972 97.101667 96.880000 +0.010000 0.093801 0.096976 97.466667 97.120000 +0.010000 0.084180 0.094102 97.731667 97.140000 +0.010000 0.076191 0.085866 97.956667 97.510000 +0.010000 0.069832 0.081906 98.118333 97.590000 +0.010000 0.064328 0.079128 98.340000 97.710000 +0.010000 0.059059 0.077403 98.468333 97.740000 +0.010000 0.055209 0.074700 98.570000 97.780000 +0.010000 0.051301 0.072494 98.680000 97.880000 +0.010000 0.047872 0.071196 98.815000 97.770000 +0.010000 0.044999 0.071158 98.900000 97.790000 +0.010000 0.042177 0.066217 98.998333 97.990000 +0.010000 0.039857 0.066495 99.031667 97.970000 +0.010000 0.037694 0.065470 99.106667 97.950000 diff --git a/pid.txt b/pid.txt new file mode 100644 index 0000000..fbcc24a --- /dev/null +++ b/pid.txt @@ -0,0 +1,21 @@ +Learning Rate Train Loss Valid Loss Train Acc. Valid Acc. 
+0.010000 0.352578 0.173396 89.950000 95.000000 +0.010000 0.147856 0.116742 95.783333 96.680000 +0.010000 0.100508 0.088020 97.140000 97.380000 +0.010000 0.076910 0.079098 97.785000 97.540000 +0.010000 0.062036 0.076330 98.281667 97.700000 +0.010000 0.052133 0.068628 98.568333 97.940000 +0.010000 0.044094 0.064560 98.823333 97.970000 +0.010000 0.037067 0.063429 99.020000 98.080000 +0.010000 0.032099 0.061729 99.188333 97.950000 +0.010000 0.028184 0.059942 99.316667 98.050000 +0.010000 0.024805 0.057361 99.466667 98.100000 +0.010000 0.022183 0.058508 99.556667 98.140000 +0.010000 0.019601 0.054335 99.636667 98.130000 +0.010000 0.017598 0.057352 99.720000 98.110000 +0.010000 0.016092 0.055072 99.766667 98.140000 +0.010000 0.014578 0.055409 99.840000 98.200000 +0.010000 0.013609 0.054895 99.841667 98.140000 +0.010000 0.012894 0.053356 99.853333 98.220000 +0.010000 0.011809 0.053756 99.878333 98.210000 +0.010000 0.011457 0.052060 99.896667 98.260000 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..848436b --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,11 @@ +"""Useful utils +""" +from .misc import * +from .logger import * +from .visualize import * +from .eval import * + +# progress bar +import os, sys +sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) +from progress.bar import Bar as Bar \ No newline at end of file diff --git a/utils/eval.py b/utils/eval.py new file mode 100644 index 0000000..5051350 --- /dev/null +++ b/utils/eval.py @@ -0,0 +1,18 @@ +from __future__ import print_function, absolute_import + +__all__ = ['accuracy'] + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..cffeb9d --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,129 @@ +# A simple torch style logger +# (C) Wei YANG 2017 +from __future__ import absolute_import +import matplotlib.pyplot as plt +import os +import sys +import numpy as np + +__all__ = ['Logger', 'LoggerMonitor', 'savefig'] + +def savefig(fname, dpi=None): + dpi = 150 if dpi == None else dpi + plt.savefig(fname, dpi=dpi) + +def plot_overlap(logger, names=None): + names = logger.names if names == None else names + numbers = logger.numbers + for _, name in enumerate(names): + x = np.arange(len(numbers[name])) + plt.plot(x, np.asarray(numbers[name])) + + return [logger.title for name in names] +# return [logger.title + '(' + name + ')' for name in names] + +class Logger(object): + '''Save training process to log file with simple plot function.''' + def __init__(self, fpath, title=None, resume=False): + self.file = None + self.resume = resume + self.title = '' if title == None else title + if fpath is not None: + if resume: + self.file = open(fpath, 'r') + name = self.file.readline() + self.names = name.rstrip().split('\t') + self.numbers = {} + for _, name in enumerate(self.names): + self.numbers[name] = [] + + for numbers in self.file: + numbers = numbers.rstrip().split('\t') + for i in range(0, len(numbers)): + self.numbers[self.names[i]].append(numbers[i]) + self.file.close() + self.file = open(fpath, 'a') + else: + self.file = 
open(fpath, 'w') + + def set_names(self, names): + if self.resume: + pass + # initialize numbers as empty list + self.numbers = {} + self.names = names + for _, name in enumerate(self.names): + self.file.write(name) + self.file.write('\t') + self.numbers[name] = [] + self.file.write('\n') + self.file.flush() + + + def append(self, numbers): + assert len(self.names) == len(numbers), 'Numbers do not match names' + for index, num in enumerate(numbers): + self.file.write("{0:.6f}".format(num)) + self.file.write('\t') + self.numbers[self.names[index]].append(num) + self.file.write('\n') + self.file.flush() + + def plot(self, names=None): + names = self.names if names == None else names + numbers = self.numbers + for _, name in enumerate(names): + x = np.arange(len(numbers[name])) + plt.plot(x, np.asarray(numbers[name])) + plt.legend([self.title + '(' + name + ')' for name in names]) + plt.grid(True) + + def close(self): + if self.file is not None: + self.file.close() + +class LoggerMonitor(object): + '''Load and visualize multiple logs.''' + def __init__ (self, paths): + '''paths is a distionary with {name:filepath} pair''' + self.loggers = [] + for title, path in paths.items(): + logger = Logger(path, title=title, resume=True) + self.loggers.append(logger) + + def plot(self, names=None): + plt.figure() + plt.subplot(121) + legend_text = [] + for logger in self.loggers: + legend_text += plot_overlap(logger, names) + plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) + plt.grid(True) + +if __name__ == '__main__': + # # Example + # logger = Logger('test.txt') + # logger.set_names(['Train loss', 'Valid loss','Test loss']) + + # length = 100 + # t = np.arange(length) + # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 + # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 + # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 + + # for i in range(0, length): + # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) + # logger.plot() + + # Example: logger monitor + paths = { + 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', + 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', + 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', + } + + field = ['Valid Acc.'] + + monitor = LoggerMonitor(paths) + monitor.plot(names=field) + savefig('test.eps') diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..d387f59 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,76 @@ +'''Some helper functions for PyTorch, including: + - get_mean_and_std: calculate the mean and std value of dataset. + - msr_init: net parameter initialization. + - progress_bar: progress bar mimic xlua.progress. 
+''' +import errno +import os +import sys +import time +import math + +import torch.nn as nn +import torch.nn.init as init +from torch.autograd import Variable + +__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] + + +def get_mean_and_std(dataset): + '''Compute the mean and std value of dataset.''' + dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) + + mean = torch.zeros(3) + std = torch.zeros(3) + print('==> Computing mean and std..') + for inputs, targets in dataloader: + for i in range(3): + mean[i] += inputs[:,i,:,:].mean() + std[i] += inputs[:,i,:,:].std() + mean.div_(len(dataset)) + std.div_(len(dataset)) + return mean, std + +def init_params(net): + '''Init layer parameters.''' + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal(m.weight, mode='fan_out') + if m.bias: + init.constant(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + init.constant(m.weight, 1) + init.constant(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal(m.weight, std=1e-3) + if m.bias: + init.constant(m.bias, 0) + +def mkdir_p(path): + '''make dir if not exist''' + try: + os.makedirs(path) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + +class AverageMeter(object): + """Computes and stores the average and current value + Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count \ No newline at end of file diff --git a/utils/progress/.gitignore b/utils/progress/.gitignore new file mode 100644 index 0000000..488ddd5 --- /dev/null +++ b/utils/progress/.gitignore @@ -0,0 +1,4 @@ +*.pyc +*.egg-info +build/ +dist/ diff --git a/utils/progress/LICENSE b/utils/progress/LICENSE new file mode 100644 index 0000000..059cc05 --- /dev/null +++ b/utils/progress/LICENSE @@ -0,0 +1,13 @@ +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/utils/progress/MANIFEST.in b/utils/progress/MANIFEST.in new file mode 100644 index 0000000..0c73842 --- /dev/null +++ b/utils/progress/MANIFEST.in @@ -0,0 +1 @@ +include README.rst LICENSE diff --git a/utils/progress/README.rst b/utils/progress/README.rst new file mode 100644 index 0000000..3f3be76 --- /dev/null +++ b/utils/progress/README.rst @@ -0,0 +1,131 @@ +Easy progress reporting for Python +================================== + +|pypi| + +|demo| + +.. |pypi| image:: https://img.shields.io/pypi/v/progress.svg +.. 
|demo| image:: https://raw.github.com/verigak/progress/master/demo.gif + :alt: Demo + +Bars +---- + +There are 7 progress bars to choose from: + +- ``Bar`` +- ``ChargingBar`` +- ``FillingSquaresBar`` +- ``FillingCirclesBar`` +- ``IncrementalBar`` +- ``PixelBar`` +- ``ShadyBar`` + +To use them, just call ``next`` to advance and ``finish`` to finish: + +.. code-block:: python + + from progress.bar import Bar + + bar = Bar('Processing', max=20) + for i in range(20): + # Do some work + bar.next() + bar.finish() + +The result will be a bar like the following: :: + + Processing |############# | 42/100 + +To simplify the common case where the work is done in an iterator, you can +use the ``iter`` method: + +.. code-block:: python + + for i in Bar('Processing').iter(it): + # Do some work + +Progress bars are very customizable, you can change their width, their fill +character, their suffix and more: + +.. code-block:: python + + bar = Bar('Loading', fill='@', suffix='%(percent)d%%') + +This will produce a bar like the following: :: + + Loading |@@@@@@@@@@@@@ | 42% + +You can use a number of template arguments in ``message`` and ``suffix``: + +========== ================================ +Name Value +========== ================================ +index current value +max maximum value +remaining max - index +progress index / max +percent progress * 100 +avg simple moving average time per item (in seconds) +elapsed elapsed time in seconds +elapsed_td elapsed as a timedelta (useful for printing as a string) +eta avg * remaining +eta_td eta as a timedelta (useful for printing as a string) +========== ================================ + +Instead of passing all configuration options on instatiation, you can create +your custom subclass: + +.. code-block:: python + + class FancyBar(Bar): + message = 'Loading' + fill = '*' + suffix = '%(percent).1f%% - %(eta)ds' + +You can also override any of the arguments or create your own: + +.. code-block:: python + + class SlowBar(Bar): + suffix = '%(remaining_hours)d hours remaining' + @property + def remaining_hours(self): + return self.eta // 3600 + + +Spinners +======== + +For actions with an unknown number of steps you can use a spinner: + +.. code-block:: python + + from progress.spinner import Spinner + + spinner = Spinner('Loading ') + while state != 'FINISHED': + # Do some work + spinner.next() + +There are 5 predefined spinners: + +- ``Spinner`` +- ``PieSpinner`` +- ``MoonSpinner`` +- ``LineSpinner`` +- ``PixelSpinner`` + + +Other +===== + +There are a number of other classes available too, please check the source or +subclass one of them to create your own. + + +License +======= + +progress is licensed under ISC diff --git a/utils/progress/demo.gif b/utils/progress/demo.gif new file mode 100644 index 0000000..64b1e95 Binary files /dev/null and b/utils/progress/demo.gif differ diff --git a/utils/progress/progress/__init__.py b/utils/progress/progress/__init__.py new file mode 100644 index 0000000..09dfc1e --- /dev/null +++ b/utils/progress/progress/__init__.py @@ -0,0 +1,127 @@ +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from __future__ import division + +from collections import deque +from datetime import timedelta +from math import ceil +from sys import stderr +from time import time + + +__version__ = '1.3' + + +class Infinite(object): + file = stderr + sma_window = 10 # Simple Moving Average window + + def __init__(self, *args, **kwargs): + self.index = 0 + self.start_ts = time() + self.avg = 0 + self._ts = self.start_ts + self._xput = deque(maxlen=self.sma_window) + for key, val in kwargs.items(): + setattr(self, key, val) + + def __getitem__(self, key): + if key.startswith('_'): + return None + return getattr(self, key, None) + + @property + def elapsed(self): + return int(time() - self.start_ts) + + @property + def elapsed_td(self): + return timedelta(seconds=self.elapsed) + + def update_avg(self, n, dt): + if n > 0: + self._xput.append(dt / n) + self.avg = sum(self._xput) / len(self._xput) + + def update(self): + pass + + def start(self): + pass + + def finish(self): + pass + + def next(self, n=1): + now = time() + dt = now - self._ts + self.update_avg(n, dt) + self._ts = now + self.index = self.index + n + self.update() + + def iter(self, it): + try: + for x in it: + yield x + self.next() + finally: + self.finish() + + +class Progress(Infinite): + def __init__(self, *args, **kwargs): + super(Progress, self).__init__(*args, **kwargs) + self.max = kwargs.get('max', 100) + + @property + def eta(self): + return int(ceil(self.avg * self.remaining)) + + @property + def eta_td(self): + return timedelta(seconds=self.eta) + + @property + def percent(self): + return self.progress * 100 + + @property + def progress(self): + return min(1, self.index / self.max) + + @property + def remaining(self): + return max(self.max - self.index, 0) + + def start(self): + self.update() + + def goto(self, index): + incr = index - self.index + self.next(incr) + + def iter(self, it): + try: + self.max = len(it) + except TypeError: + pass + + try: + for x in it: + yield x + self.next() + finally: + self.finish() diff --git a/utils/progress/progress/bar.py b/utils/progress/progress/bar.py new file mode 100644 index 0000000..5ee968f --- /dev/null +++ b/utils/progress/progress/bar.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from __future__ import unicode_literals +from . 
import Progress +from .helpers import WritelnMixin + + +class Bar(WritelnMixin, Progress): + width = 32 + message = '' + suffix = '%(index)d/%(max)d' + bar_prefix = ' |' + bar_suffix = '| ' + empty_fill = ' ' + fill = '#' + hide_cursor = True + + def update(self): + filled_length = int(self.width * self.progress) + empty_length = self.width - filled_length + + message = self.message % self + bar = self.fill * filled_length + empty = self.empty_fill * empty_length + suffix = self.suffix % self + line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, + suffix]) + self.writeln(line) + + +class ChargingBar(Bar): + suffix = '%(percent)d%%' + bar_prefix = ' ' + bar_suffix = ' ' + empty_fill = '∙' + fill = '█' + + +class FillingSquaresBar(ChargingBar): + empty_fill = '▢' + fill = '▣' + + +class FillingCirclesBar(ChargingBar): + empty_fill = '◯' + fill = '◉' + + +class IncrementalBar(Bar): + phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') + + def update(self): + nphases = len(self.phases) + filled_len = self.width * self.progress + nfull = int(filled_len) # Number of full chars + phase = int((filled_len - nfull) * nphases) # Phase of last char + nempty = self.width - nfull # Number of empty chars + + message = self.message % self + bar = self.phases[-1] * nfull + current = self.phases[phase] if phase > 0 else '' + empty = self.empty_fill * max(0, nempty - len(current)) + suffix = self.suffix % self + line = ''.join([message, self.bar_prefix, bar, current, empty, + self.bar_suffix, suffix]) + self.writeln(line) + + +class PixelBar(IncrementalBar): + phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') + + +class ShadyBar(IncrementalBar): + phases = (' ', '░', '▒', '▓', '█') diff --git a/utils/progress/progress/counter.py b/utils/progress/progress/counter.py new file mode 100644 index 0000000..6b45a1e --- /dev/null +++ b/utils/progress/progress/counter.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from __future__ import unicode_literals +from . 
import Infinite, Progress +from .helpers import WriteMixin + + +class Counter(WriteMixin, Infinite): + message = '' + hide_cursor = True + + def update(self): + self.write(str(self.index)) + + +class Countdown(WriteMixin, Progress): + hide_cursor = True + + def update(self): + self.write(str(self.remaining)) + + +class Stack(WriteMixin, Progress): + phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') + hide_cursor = True + + def update(self): + nphases = len(self.phases) + i = min(nphases - 1, int(self.progress * nphases)) + self.write(self.phases[i]) + + +class Pie(Stack): + phases = ('○', '◔', '◑', '◕', '●') diff --git a/utils/progress/progress/helpers.py b/utils/progress/progress/helpers.py new file mode 100644 index 0000000..9ed90b2 --- /dev/null +++ b/utils/progress/progress/helpers.py @@ -0,0 +1,91 @@ +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from __future__ import print_function + + +HIDE_CURSOR = '\x1b[?25l' +SHOW_CURSOR = '\x1b[?25h' + + +class WriteMixin(object): + hide_cursor = False + + def __init__(self, message=None, **kwargs): + super(WriteMixin, self).__init__(**kwargs) + self._width = 0 + if message: + self.message = message + + if self.file.isatty(): + if self.hide_cursor: + print(HIDE_CURSOR, end='', file=self.file) + print(self.message, end='', file=self.file) + self.file.flush() + + def write(self, s): + if self.file.isatty(): + b = '\b' * self._width + c = s.ljust(self._width) + print(b + c, end='', file=self.file) + self._width = max(self._width, len(s)) + self.file.flush() + + def finish(self): + if self.file.isatty() and self.hide_cursor: + print(SHOW_CURSOR, end='', file=self.file) + + +class WritelnMixin(object): + hide_cursor = False + + def __init__(self, message=None, **kwargs): + super(WritelnMixin, self).__init__(**kwargs) + if message: + self.message = message + + if self.file.isatty() and self.hide_cursor: + print(HIDE_CURSOR, end='', file=self.file) + + def clearln(self): + if self.file.isatty(): + print('\r\x1b[K', end='', file=self.file) + + def writeln(self, line): + if self.file.isatty(): + self.clearln() + print(line, end='', file=self.file) + self.file.flush() + + def finish(self): + if self.file.isatty(): + print(file=self.file) + if self.hide_cursor: + print(SHOW_CURSOR, end='', file=self.file) + + +from signal import signal, SIGINT +from sys import exit + + +class SigIntMixin(object): + """Registers a signal handler that calls finish on SIGINT""" + + def __init__(self, *args, **kwargs): + super(SigIntMixin, self).__init__(*args, **kwargs) + signal(SIGINT, self._sigint_handler) + + def _sigint_handler(self, signum, frame): + self.finish() + exit(0) diff --git a/utils/progress/progress/spinner.py b/utils/progress/progress/spinner.py new file mode 100644 index 0000000..464c7b2 --- /dev/null +++ b/utils/progress/progress/spinner.py 
@@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2012 Giorgos Verigakis +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from __future__ import unicode_literals +from . import Infinite +from .helpers import WriteMixin + + +class Spinner(WriteMixin, Infinite): + message = '' + phases = ('-', '\\', '|', '/') + hide_cursor = True + + def update(self): + i = self.index % len(self.phases) + self.write(self.phases[i]) + + +class PieSpinner(Spinner): + phases = ['◷', '◶', '◵', '◴'] + + +class MoonSpinner(Spinner): + phases = ['◑', '◒', '◐', '◓'] + + +class LineSpinner(Spinner): + phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] + +class PixelSpinner(Spinner): + phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] diff --git a/utils/progress/setup.py b/utils/progress/setup.py new file mode 100755 index 0000000..c877781 --- /dev/null +++ b/utils/progress/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +from setuptools import setup + +import progress + + +setup( + name='progress', + version=progress.__version__, + description='Easy to use progress bars', + long_description=open('README.rst').read(), + author='Giorgos Verigakis', + author_email='verigak@gmail.com', + url='http://github.com/verigak/progress/', + license='ISC', + packages=['progress'], + classifiers=[ + 'Environment :: Console', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: ISC License (ISCL)', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + ] +) diff --git a/utils/progress/test_progress.py b/utils/progress/test_progress.py new file mode 100755 index 0000000..0f68b01 --- /dev/null +++ b/utils/progress/test_progress.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +from __future__ import print_function + +import random +import time + +from progress.bar import (Bar, ChargingBar, FillingSquaresBar, + FillingCirclesBar, IncrementalBar, PixelBar, + ShadyBar) +from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, + PixelSpinner) +from progress.counter import Counter, Countdown, Stack, Pie + + +def sleep(): + t = 0.01 + t += t * random.uniform(-0.1, 0.1) # Add some variance + time.sleep(t) + + +for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): + suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' + bar = bar_cls(bar_cls.__name__, suffix=suffix) + for i in bar.iter(range(200)): + sleep() + +for bar_cls in (IncrementalBar, PixelBar, ShadyBar): + suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' + bar = bar_cls(bar_cls.__name__, suffix=suffix) + for i in bar.iter(range(200)): + sleep() + +for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): + for i in 
spin(spin.__name__ + ' ').iter(range(100)): + sleep() + print() + +for singleton in (Counter, Countdown, Stack, Pie): + for i in singleton(singleton.__name__ + ' ').iter(range(100)): + sleep() + print() + +bar = IncrementalBar('Random', suffix='%(index)d') +for i in range(100): + bar.goto(random.randint(0, 100)) + sleep() +bar.finish() diff --git a/utils/visualize.py b/utils/visualize.py new file mode 100644 index 0000000..51abeed --- /dev/null +++ b/utils/visualize.py @@ -0,0 +1,110 @@ +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +import numpy as np +from .misc import * + +__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] + +# functions to show an image +def make_image(img, mean=(0,0,0), std=(1,1,1)): + for i in range(0, 3): + img[i] = img[i] * std[i] + mean[i] # unnormalize + npimg = img.numpy() + return np.transpose(npimg, (1, 2, 0)) + +def gauss(x,a,b,c): + return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) + +def colorize(x): + ''' Converts a one-channel grayscale image to a color heatmap image ''' + if x.dim() == 2: + torch.unsqueeze(x, 0, out=x) + if x.dim() == 3: + cl = torch.zeros([3, x.size(1), x.size(2)]) + cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) + cl[1] = gauss(x,1,.5,.3) + cl[2] = gauss(x,1,.2,.3) + cl[cl.gt(1)] = 1 + elif x.dim() == 4: + cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) + cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) + cl[:,1,:,:] = gauss(x,1,.5,.3) + cl[:,2,:,:] = gauss(x,1,.2,.3) + return cl + +def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): + images = make_image(torchvision.utils.make_grid(images), Mean, Std) + plt.imshow(images) + plt.show() + + +def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): + im_size = images.size(2) + + # save for adding mask + im_data = images.clone() + for i in range(0, 3): + im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize + + images = make_image(torchvision.utils.make_grid(images), Mean, Std) + plt.subplot(2, 1, 1) + plt.imshow(images) + plt.axis('off') + + # for b in range(mask.size(0)): + # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) + mask_size = mask.size(2) + # print('Max %f Min %f' % (mask.max(), mask.min())) + mask = (upsampling(mask, scale_factor=im_size/mask_size)) + # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) + # for c in range(3): + # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] + + # print(mask.size()) + mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) + # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) + plt.subplot(2, 1, 2) + plt.imshow(mask) + plt.axis('off') + +def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): + im_size = images.size(2) + + # save for adding mask + im_data = images.clone() + for i in range(0, 3): + im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize + + images = make_image(torchvision.utils.make_grid(images), Mean, Std) + plt.subplot(1+len(masklist), 1, 1) + plt.imshow(images) + plt.axis('off') + + for i in range(len(masklist)): + mask = masklist[i].data.cpu() + # for b in range(mask.size(0)): + # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) + mask_size = mask.size(2) + # print('Max %f Min %f' % (mask.max(), mask.min())) + mask = (upsampling(mask, scale_factor=im_size/mask_size)) + # mask = colorize(upsampling(mask, 
scale_factor=im_size/mask_size)) + # for c in range(3): + # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] + + # print(mask.size()) + mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) + # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) + plt.subplot(1+len(masklist), 1, i+2) + plt.imshow(mask) + plt.axis('off') + + + +# x = torch.zeros(1, 3, 3) +# out = colorize(x) +# out_im = make_image(out) +# plt.imshow(out_im) +# plt.show() \ No newline at end of file
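Note: mnist_pid.py does `from pid import PIDOptimizer`, but no pid.py is included in this diff. The sketch below is only an illustration of what a PID-style SGD variant can look like, written against current PyTorch (the scripts above use the older Variable/.data[0] API). The class name `PIDSketchOptimizer`, the exact update rule, and the buffer layout are assumptions; only the hyperparameter names (lr, momentum, weight_decay, I, D) mirror the call in mnist_pid.py.

import torch
from torch.optim.optimizer import Optimizer


class PIDSketchOptimizer(Optimizer):
    """SGD-style optimizer whose update mixes a proportional term (current
    gradient), an integral term (a momentum buffer of past gradients) and a
    derivative term (a smoothed estimate of how the gradient is changing)."""

    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0.0,
                 I=1.0, D=200.0):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        I=I, D=D)
        super(PIDSketchOptimizer, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            I, D = group['I'], group['D']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if weight_decay != 0:
                    # L2 regularization folded into the gradient, as in torch.optim.SGD
                    grad = grad.add(p, alpha=weight_decay)
                state = self.state[p]
                if len(state) == 0:
                    state['i_buf'] = torch.zeros_like(p)      # integral: accumulated gradients
                    state['d_buf'] = torch.zeros_like(p)      # derivative: smoothed gradient change
                    state['prev_grad'] = torch.zeros_like(p)
                i_buf, d_buf = state['i_buf'], state['d_buf']
                prev_grad = state['prev_grad']
                i_buf.mul_(momentum).add_(grad)                                   # I term
                d_buf.mul_(momentum).add_(grad - prev_grad, alpha=1 - momentum)   # D term
                prev_grad.copy_(grad)
                # P + I + D combination, scaled by the learning rate
                p.add_(grad + I * i_buf + D * d_buf, alpha=-lr)
        return loss

With the settings used in mnist_pid.py this would be constructed as `PIDSketchOptimizer(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001, I=1.0, D=200.0)`; the actual update rule should be taken from the repository's pid.py once it is committed.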