Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support KLD metric and support evaluation for probabilistic models #108

Merged
merged 6 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support KLD metric
  • Loading branch information
LeoXing1996 committed Sep 7, 2021
commit f479fdc026155e940f72f192fb2783767d3a347f
6 changes: 4 additions & 2 deletions mmgen/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
single_gpu_evaluation, single_gpu_online_evaluation)
from .metric_utils import slerp
from .metrics import IS, MS_SSIM, PR, SWD, ms_ssim, sliced_wasserstein
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, gaussian_kld, ms_ssim,
sliced_wasserstein)

__all__ = [
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader'
'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
'gaussian_kld'
]
41 changes: 40 additions & 1 deletion mmgen/core/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,26 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
of the metric table include training configuration and ckpt.
kwargs (dict): Other arguments.
"""
# separate metrics into special metrics and vanilla metrics.
# separate metrics into special metrics, probabilistic metrics and vanilla
# metrics.
# For vanilla metrics, images are generated in a random way, and are
# shared by these metrics. For special metrics like 'PPL', images are
# generated in a metric-special way and not shared between different
# metrics.
# For probabilistic metrics like 'GaussianKLD', they do not
# receive images but receive a dict with corresponding probabilistic
# parameters. To make the model return these parameters, it is called in
# 'reconstruction' mode.

special_metrics = []
probabilistic_metric = []
vanilla_metrics = []
special_metric_name = ['PPL']
probabilistic_metric_name = ['GaussianKLD']
for metric in metrics:
if metric.name in special_metric_name:
special_metrics.append(metric)
elif metric.name in probabilistic_metric_name:
probabilistic_metric.append(metric)
else:
vanilla_metrics.append(metric)

Expand Down Expand Up @@ -329,6 +338,36 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
# finish the pbar stdout
sys.stdout.write('\n')

# feed probabilistic metric
for metric in probabilistic_metric:
metric.prepare()
pbar = mmcv.ProgressBar(len(data_loader))
# here we assume the probabilistic model supports reconstruction mode
LeoXing1996 marked this conversation as resolved.
Show resolved Hide resolved
kwargs['mode'] = 'reconstruction'
for data in data_loader:
# key for unconditional GAN
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
raise KeyError('Cannot found key for images in data_dict. '
'Only support `real_img` for unconditional '
'datasets and `img` for conditional '
'datasets.')

if reals.shape[1] not in [1, 3]:
raise RuntimeError('real images should have one or three '
'channels in the first, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
LeoXing1996 marked this conversation as resolved.
Show resolved Hide resolved

prob_dict = model(reals, return_loss=False, **kwargs)
num_feed = metric.feed(prob_dict, 'reals')
pbar.update(num_feed)

for metric in metrics:
metric.summary()

Expand Down
180 changes: 172 additions & 8 deletions mmgen/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,46 @@ def sliced_wasserstein(distribution_a,
return sum(results) / dir_repeats


def gaussian_kld(mean_1, mean_2, logvar_1, logvar_2, base='e'):
    r"""Compute the element-wise KL divergence between two Gaussians.

    Both distributions are described by their mean and log-variance. The
    divergence is evaluated with the natural logarithm and, when
    ``base == '2'``, rescaled by ``1 / ln(2)``.

    .. math::
        :nowrap:

        \begin{eqnarray}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{eqnarray}

    Args:
        mean_1 (torch.Tensor): Mean of the first (target) distribution.
        mean_2 (torch.Tensor): Mean of the second (predicted) distribution.
        logvar_1 (torch.Tensor): Log variance of the first (target)
            distribution.
        logvar_2 (torch.Tensor): Log variance of the second (predicted)
            distribution.
        base (str, optional): Log base of the returned KLD. ``'e'`` and
            ``'2'`` are supported. Defaults to ``'e'``.

    Returns:
        torch.Tensor: Element-wise KLD between the two distributions.

    Raises:
        ValueError: If ``base`` is neither ``'e'`` nor ``'2'``.
    """
    if base not in ('e', '2'):
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')

    # sigma_1^2 / sigma_2^2, expressed through the log-variances.
    variance_ratio = torch.exp(logvar_1 - logvar_2)
    # (mu_1 - mu_2)^2 / sigma_2^2
    scaled_sq_diff = ((mean_1 - mean_2)**2) * torch.exp(-logvar_2)
    kld = 0.5 * (-1.0 + logvar_2 - logvar_1 + variance_ratio +
                 scaled_sq_diff)

    # Convert from nats to bits only when explicitly requested.
    return kld / np.log(2) if base == '2' else kld


class Metric(ABC):
"""The abstract base class of metrics. Basically, we split calculation into
three steps. First, we initialize the metric object and do some
Expand Down Expand Up @@ -383,18 +423,28 @@ def feed(self, batch, mode):
operation in 'feed_op' function.

Args:
batch (Tensor): Images feeded into metric object with order "NCHW"
and range [-1, 1].
batch (Tensor|dict[Tensor]): Images or dict to be feeded into
LeoXing1996 marked this conversation as resolved.
Show resolved Hide resolved
metric object. If ``Tensor`` is passed, the order of ``Tensor``
should be "NCHW". If ``dict`` is passed, each term in the
``dict`` are ``Tensor`` with order "NCHW".
mode (str): Mark the batch as real or fake images. Value can be
'reals' or 'fakes',
"""
if mode == 'reals':
if self.num_real_feeded == self.num_real_need:
return 0

batch_size = batch.shape[0]
end = min(batch_size, self.num_real_need - self.num_real_feeded)
self.feed_op(batch[:end, :, :, :], mode)
if isinstance(batch, dict):
batch_size = [v for v in batch.values()][0].shape[0]
end = min(batch_size,
self.num_real_need - self.num_real_feeded)
batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
else:
batch_size = batch.shape[0]
end = min(batch_size,
self.num_real_need - self.num_real_feeded)
batch_to_feed = batch[:end, ...]
self.feed_op(batch_to_feed, mode)
self.num_real_feeded += end
return end

Expand All @@ -404,13 +454,17 @@ def feed(self, batch, mode):

batch_size = batch.shape[0]
end = min(batch_size, self.num_fake_need - self.num_fake_feeded)
self.feed_op(batch[:end, :, :, :], mode)
if isinstance(batch, dict):
batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
else:
batch_to_feed = batch[:end, ...]
self.feed_op(batch_to_feed, mode)
self.num_fake_feeded += end
return end
else:
raise ValueError(
f"The expected mode should be set to 'reals' or 'fakes,\
but got '{mode}'")
'The expected mode should be set to \'reals\' or \'fakes\','
f'but got \'{mode}\'')

def check(self):
"""Check the numbers of image."""
Expand Down Expand Up @@ -1343,3 +1397,113 @@ def __next__(self):

self.idx += 1
return image


class GaussianKLD(Metric):
    r"""Gaussian KLD (Kullback-Leibler divergence) metric.

    The KLD between two Gaussian distributions is computed from their
    ``mean`` and ``log_variance``. Each batch passed to ``feed`` must be a
    dict containing ``mean_pred``, ``mean_target``, ``logvar_pred`` and
    ``logvar_target``. Only batches fed in ``'reals'`` mode are used;
    ``'fakes'`` batches are ignored.

    The calculation of KLD can be formulated as:

    .. math::
        :nowrap:

        \begin{eqnarray}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{eqnarray}

    where `p` and `q` denote target and predicted distribution respectively.

    Args:
        num_images (int): The number of samples to be tested.
        log_base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        reduction (string, optional): Specifies the reduction to apply to the
            output. Support ``'batchmean'``, ``'sum'`` and ``'mean'``. If
            ``reduction == 'batchmean'``, the sum of the output will be
            divided by batchsize. If ``reduction == 'sum'``, the output will
            be summed. If ``reduction == 'mean'``, the output will be divided
            by the number of elements in the output. Defaults to
            ``'batchmean'``.
    """
    name = 'GaussianKLD'

    def __init__(self, num_images, log_base='e', reduction='batchmean'):
        super().__init__(num_images, image_shape=None)
        supported_reductions = ['sum', 'batchmean', 'mean']
        supported_log_bases = ['e', '2']
        assert reduction in supported_reductions, (
            'We only support reduction for \'batchmean\', \'sum\' '
            'and \'mean\'')
        assert log_base in supported_log_bases, (
            'We only support log_base for \'e\' and \'2\'')
        self.log_base = log_base
        self.reduction = reduction
        # No fake samples are consumed by this metric.
        # NOTE(review): presumably this marks the fake quota as already
        # satisfied for the base class -- confirm against ``Metric``.
        self.num_fake_feeded = self.num_images

    def prepare(self):
        """Reset the collected KLD values and the real-sample counter."""
        self.KLD = []
        self.num_real_feeded = 0

    @torch.no_grad()
    def feed_op(self, batch, mode):
        """Compute the KLD of one batch and store it.

        Args:
            batch (dict[str, Tensor]): Dict holding ``mean_pred``,
                ``mean_target``, ``logvar_pred`` and ``logvar_target``.
            mode (str): The mode of current data batch. 'reals' or 'fakes'.
                Batches in 'fakes' mode are skipped.

        Raises:
            KeyError: If any of the required keys is missing in ``batch``.
        """
        if mode == 'fakes':
            return
        assert isinstance(batch, dict), ('To calculate GaussianKLD loss, a '
                                         'dict contains probabilistic '
                                         'parameters is required.')
        # check required keys
        require_keys = [
            'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
        ]
        missing_keys = [key for key in require_keys if key not in batch]
        if missing_keys:
            raise KeyError(f'The input dict must require {require_keys} at '
                           'the same time. But keys in the given dict are '
                           f'{batch.keys()}. Some of the requirements are '
                           'missing.')

        # The target distribution plays the role of `p`, the prediction `q`.
        kld = gaussian_kld(
            batch['mean_target'],
            batch['mean_pred'],
            batch['logvar_target'],
            batch['logvar_pred'],
            base=self.log_base)

        distributed = dist.is_initialized()
        if distributed:
            world_size = dist.get_world_size()
            gathered = [torch.zeros_like(kld) for _ in range(world_size)]
            dist.all_gather(gathered, kld)
            kld = torch.cat(gathered, dim=0)

        # In distributed runs only rank 0 keeps the gathered values.
        if not distributed or dist.get_rank() == 0:
            self.KLD.append(kld.cpu())

    @torch.no_grad()
    def summary(self):
        """Reduce all collected KLD values to a single number.

        The formatted value is also stored in ``self._result_str``.

        Returns:
            numpy scalar: KLD reduced according to ``self.reduction``.
        """
        kld = torch.cat(self.KLD, dim=0)
        assert kld.shape[0] >= self.num_images
        kld_np = kld.numpy()
        if self.reduction == 'mean':
            kld_result = np.mean(kld_np)
        elif self.reduction == 'sum':
            kld_result = np.sum(kld_np)
        else:
            # 'batchmean': total divided by the number of samples.
            kld_result = np.sum(kld_np) / kld_np.shape[0]
        self._result_str = f'{kld_result:.4f}'
        return kld_result
75 changes: 70 additions & 5 deletions tests/test_cores/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch

from mmgen.core.evaluation.metric_utils import extract_inception_features
from mmgen.core.evaluation.metrics import FID, IS, MS_SSIM, PPL, PR, SWD
from mmgen.core.evaluation.metrics import (FID, IS, MS_SSIM, PPL, PR, SWD,
GaussianKLD)
from mmgen.datasets import UnconditionalImageDataset, build_dataloader
from mmgen.models import build_model
from mmgen.models.architectures import InceptionV3
Expand Down Expand Up @@ -65,6 +66,69 @@ def test_ms_ssim():
assert ssim_result < 1


def test_kld_gaussian():
    """Check ``GaussianKLD`` against a numerically integrated reference."""
    # The reference value is obtained by integrating over a dense grid, so
    # the tensors are kept small to bound the time and memory cost.
    tar_shape = [2, 3, 4, 4]
    param_shape = (*tar_shape, 1)
    mean1 = torch.rand(*param_shape)
    mean2 = torch.rand(*param_shape)
    var1 = torch.randint(1, 3, param_shape)
    var2 = torch.randint(1, 3, param_shape)

    def gaussian_pdf(x, mean, var):
        normalizer = 1 / np.sqrt(2 * np.pi * var)
        exponent = -(x - mean)**2 / (2 * var)
        return normalizer * torch.exp(exponent)

    def evaluate(metric_obj, batch):
        # Run the full prepare -> feed -> summary cycle on a single batch.
        metric_obj.prepare()
        metric_obj.feed(batch, 'reals')
        return metric_obj.summary()

    delta = 0.0001
    indice = torch.arange(-5, 5, delta).repeat(*mean1.shape)
    p = gaussian_pdf(indice, mean1, var1)  # pdf of target distribution
    q = gaussian_pdf(indice, mean2, var2)  # pdf of predicted distribution

    integrand = p * torch.log(p / q) * delta
    kld_manually = integrand.sum(dim=(1, 2, 3, 4)).mean()

    data = dict(
        mean_pred=mean2,
        mean_target=mean1,
        logvar_pred=torch.log(var2),
        logvar_target=torch.log(var1))

    kld = evaluate(GaussianKLD(2), data)
    # The tolerance is loose because delta cannot be made small enough for
    # a precise numerical integration.
    np.testing.assert_almost_equal(kld, kld_manually, decimal=1)

    kld_base_2 = evaluate(GaussianKLD(2, log_base='2'), data)
    np.testing.assert_almost_equal(kld_base_2, kld / np.log(2), decimal=4)

    # test wrong log_base
    with pytest.raises(AssertionError):
        GaussianKLD(2, log_base='10')

    # test other reductions --> mean, then sum
    for reduction in ('mean', 'sum'):
        kld = evaluate(GaussianKLD(2, reduction=reduction), data)

    # test other reduction --> error
    with pytest.raises(AssertionError):
        GaussianKLD(2, reduction='none')


class TestExtractInceptionFeat:

@classmethod
Expand Down Expand Up @@ -98,8 +162,9 @@ def test_extr_inception_feat(self):

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_extr_inception_feat_cuda(self):
inception = torch.nn.DataParallel(self.inception)
feat = extract_inception_features(self.data_loader, inception, 5)
# inception = torch.nn.DataParallel(self.inception)
feat = extract_inception_features(self.data_loader,
self.inception.cuda(), 5)
assert feat.shape[0] == 5

@torch.no_grad()
Expand All @@ -113,8 +178,8 @@ def test_with_tero_implement(self):
# Tero implementation
net = torch.jit.load(
'./work_dirs/cache/inception-2015-12-05.pt').eval().cuda()
net = torch.nn.DataParallel(net)
feature_tero = net(img, return_features=True)
# net = torch.nn.DataParallel(net)
feature_tero = net(img.cuda(), return_features=True)

print(feature_ours.shape)

Expand Down