[Feature] Support KLD metric and support evaluation for probabilistic models #108

Merged · 6 commits · Sep 15, 2021
19 changes: 15 additions & 4 deletions mmgen/core/evaluation/__init__.py
@@ -3,10 +3,21 @@
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
single_gpu_evaluation, single_gpu_online_evaluation)
from .metric_utils import slerp
from .metrics import IS, MS_SSIM, PR, SWD, ms_ssim, sliced_wasserstein
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
sliced_wasserstein)

__all__ = [
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader'
'MS_SSIM',
'SWD',
'ms_ssim',
'sliced_wasserstein',
'single_gpu_evaluation',
'single_gpu_online_evaluation',
'PR',
'IS',
'slerp',
'GenerativeEvalHook',
'make_metrics_table',
'make_vanilla_dataloader',
'GaussianKLD',
]
43 changes: 41 additions & 2 deletions mmgen/core/evaluation/evaluation.py
@@ -229,17 +229,26 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
of the metric table include training configuration and ckpt.
kwargs (dict): Other arguments.
"""
# separate metrics into special metrics and vanilla metrics.
# separate metrics into special metrics, probabilistic metrics and vanilla
# metrics.
# For vanilla metrics, images are generated in a random way, and are
# shared by these metrics. For special metrics like 'PPL', images are
# generated in a metric-special way and not shared between different
# metrics.
# For probabilistic metrics like 'GaussianKLD', images are not fed in;
# instead, a dict with the corresponding probabilistic parameters is
# fed. To make the model return these parameters, the forward mode is
# set to 'reconstruction'.

special_metrics = []
probabilistic_metrics = []
vanilla_metrics = []
special_metric_name = ['PPL']
probabilistic_metric_name = ['GaussianKLD']
for metric in metrics:
if metric.name in special_metric_name:
special_metrics.append(metric)
elif metric.name in probabilistic_metric_name:
probabilistic_metrics.append(metric)
else:
vanilla_metrics.append(metric)

@@ -271,7 +280,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
'channels in the first dimension, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
reals = reals.repeat(1, 3, 1, 1)
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
@@ -329,6 +338,36 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
# finish the pbar stdout
sys.stdout.write('\n')

# feed probabilistic metric
for metric in probabilistic_metrics:
metric.prepare()
pbar = mmcv.ProgressBar(len(data_loader))
# here we assume the probabilistic model has a 'reconstruction' mode
kwargs['mode'] = 'reconstruction'
for data in data_loader:
# key for unconditional GAN
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
raise KeyError('Cannot find key for images in data_dict. '
'Only support `real_img` for unconditional '
'datasets and `img` for conditional '
'datasets.')

if reals.shape[1] not in [1, 3]:
raise RuntimeError('real images should have one or three '
'channels in the first dimension, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = reals.repeat(1, 3, 1, 1)

prob_dict = model(reals, return_loss=False, **kwargs)
num_feed = metric.feed(prob_dict, 'reals')
pbar.update(num_feed)

for metric in metrics:
metric.summary()

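The probabilistic feeding loop above assumes the model can be called with ``mode='reconstruction'`` and then returns a dict of distribution parameters instead of images. A minimal end-to-end sketch of that contract (``DummyVAE`` and every shape here are illustrative assumptions, not part of this PR):

```python
import torch

from mmgen.core.evaluation import GaussianKLD


class DummyVAE(torch.nn.Module):
    """Illustrative stand-in for a probabilistic model whose
    'reconstruction' mode returns distribution parameters."""

    def forward(self, imgs, return_loss=False, mode='reconstruction'):
        n = imgs.shape[0]
        return dict(
            mean_pred=torch.randn(n, 4),
            logvar_pred=torch.randn(n, 4),
            mean_target=torch.zeros(n, 4),
            logvar_target=torch.zeros(n, 4))


model = DummyVAE()
metric = GaussianKLD(num_images=8)
metric.prepare()

reals = torch.rand(8, 3, 32, 32)
prob_dict = model(reals, return_loss=False, mode='reconstruction')
metric.feed(prob_dict, 'reals')  # only the 'reals' mode is consumed
print(metric.summary())          # scalar KLD under 'batchmean' reduction
```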
139 changes: 131 additions & 8 deletions mmgen/core/evaluation/metrics.py
@@ -3,6 +3,7 @@
import pickle
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial

import mmcv
import numpy as np
@@ -18,6 +19,7 @@
from mmgen.models.architectures import InceptionV3
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.lpips import PerceptualLoss
from mmgen.models.losses import gaussian_kld
from mmgen.utils import MMGEN_CACHE_DIR
from mmgen.utils.io_utils import download_from_url
from ..registry import METRICS
@@ -383,18 +385,28 @@ def feed(self, batch, mode):
operation in 'feed_op' function.

Args:
batch (Tensor): Images feeded into metric object with order "NCHW"
and range [-1, 1].
batch (Tensor | dict): Images or a dict to be fed into the
metric object. If a ``Tensor`` is passed, its order should be
"NCHW". If a ``dict`` is passed, each value in the ``dict``
is a ``Tensor`` with order "NCHW".
mode (str): Mark the batch as real or fake images. Value can be
'reals' or 'fakes'.
"""
if mode == 'reals':
if self.num_real_feeded == self.num_real_need:
return 0

batch_size = batch.shape[0]
end = min(batch_size, self.num_real_need - self.num_real_feeded)
self.feed_op(batch[:end, :, :, :], mode)
if isinstance(batch, dict):
batch_size = [v for v in batch.values()][0].shape[0]
end = min(batch_size,
self.num_real_need - self.num_real_feeded)
batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
else:
batch_size = batch.shape[0]
end = min(batch_size,
self.num_real_need - self.num_real_feeded)
batch_to_feed = batch[:end, ...]
self.feed_op(batch_to_feed, mode)
self.num_real_feeded += end
return end

@@ -404,13 +416,17 @@ def feed(self, batch, mode):

batch_size = batch.shape[0]
end = min(batch_size, self.num_fake_need - self.num_fake_feeded)
self.feed_op(batch[:end, :, :, :], mode)
if isinstance(batch, dict):
batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
else:
batch_to_feed = batch[:end, ...]
self.feed_op(batch_to_feed, mode)
self.num_fake_feeded += end
return end
else:
raise ValueError(
f"The expected mode should be set to 'reals' or 'fakes,\
but got '{mode}'")
'The expected mode should be set to \'reals\' or \'fakes\', '
f'but got \'{mode}\'')

def check(self):
"""Check the numbers of image."""
@@ -1343,3 +1359,110 @@ def __next__(self):

self.idx += 1
return image


class GaussianKLD(Metric):
r"""Gaussian KLD (Kullback-Leibler divergence) metric. We calculate the
KLD between two gaussian distribution via `mean` and `log_variance`.
The passed batch should be a dict instance and contain ``mean_pred``,
``mean_target``, ``logvar_pred``, ``logvar_target``.
When call ``feed`` operation, only ``reals`` mode is needed,

The calculation of KLD can be formulated as:
.. math::
:nowrap:
\begin{eqnarray}
KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
&= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
\frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
\frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
&= \log{\frac{\sigma_2}{\sigma_1}} +
\frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
\end{eqnarray}
where `p` and `q` denote the target and predicted distributions, respectively.

Args:
num_images (int): The number of samples to be tested.
base (str, optional): The log base of calculated KLD. Support
``'e'`` and ``'2'``. Defaults to ``'e'``.
reduction (string, optional): Specifies the reduction to apply to the
output. Support ``'batchmean'``, ``'sum'`` and ``'mean'``. If
``reduction == 'batchmean'``, the sum of the output will be divided
by the batch size. If ``reduction == 'sum'``, the output will be summed.
If ``reduction == 'mean'``, the output will be divided by the
number of elements in the output. Defaults to ``'batchmean'``.

"""
name = 'GaussianKLD'

def __init__(self, num_images, base='e', reduction='batchmean'):
super().__init__(num_images, image_shape=None)
assert reduction in ['sum', 'batchmean', 'mean'], (
'We only support \'batchmean\', \'sum\' and \'mean\' for reduction.')
assert base in ['e', '2'], (
'We only support \'e\' and \'2\' for the log base.')
self.reduction = reduction
self.num_fake_feeded = self.num_images
self.cal_kld = partial(
gaussian_kld, weight=1, reduction='none', base=base)

def prepare(self):
"""Prepare for evaluating models with this metric."""
self.kld = []
self.num_real_feeded = 0

@torch.no_grad()
def feed_op(self, batch, mode):
"""Feed data to the metric.

Args:
batch (dict): Dict containing the probabilistic parameters. See the
class docstring for the required keys.
mode (str): The mode of current data batch. 'reals' or 'fakes'.
"""
if mode == 'fakes':
return
assert isinstance(batch, dict), ('To calculate GaussianKLD loss, a '
'dict containing probabilistic '
'parameters is required.')
# check required keys
require_keys = [
'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
]
if any([k not in batch for k in require_keys]):
raise KeyError(f'The input dict must contain {require_keys} at '
'the same time, but the keys in the given dict are '
f'{batch.keys()}. Some of the required keys are '
'missing.')
kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
batch['logvar_target'], batch['logvar_pred'])
if dist.is_initialized():
ws = dist.get_world_size()
placeholder = [torch.zeros_like(kld) for _ in range(ws)]
dist.all_gather(placeholder, kld)
kld = torch.cat(placeholder, dim=0)

# in distributed training, we only collect features at rank-0.
if (dist.is_initialized()
and dist.get_rank() == 0) or not dist.is_initialized():
self.kld.append(kld.cpu())

@torch.no_grad()
def summary(self):
"""Summarize the results.

Returns:
dict | list: Summarized results.
"""
kld = torch.cat(self.kld, dim=0)
assert kld.shape[0] >= self.num_images
kld_np = kld.numpy()
if self.reduction == 'sum':
kld_result = np.sum(kld_np)
elif self.reduction == 'mean':
kld_result = np.mean(kld_np)
else:
kld_result = np.sum(kld_np) / kld_np.shape[0]
self._result_str = (f'{kld_result:.4f}')
return kld_result
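To get a concrete feel for the ``reduction`` options of the new metric, here is a small standalone check (all values are arbitrary): ``'mean'`` divides the summed KLD by the total number of elements, while ``'batchmean'`` divides only by the batch size.

```python
import torch

from mmgen.core.evaluation import GaussianKLD

params = dict(
    mean_pred=torch.randn(4, 8),
    mean_target=torch.zeros(4, 8),
    logvar_pred=torch.zeros(4, 8),
    logvar_target=torch.zeros(4, 8))

for reduction in ['batchmean', 'mean', 'sum']:
    metric = GaussianKLD(num_images=4, reduction=reduction)
    metric.prepare()
    metric.feed(params, 'reals')
    print(reduction, metric.summary())
# With a (4, 8) KLD map, 'batchmean' equals 'mean' * 8 and 'sum' / 4.
```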
5 changes: 3 additions & 2 deletions mmgen/models/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -5,10 +5,11 @@
r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import GeneratorPathRegularizer, gen_path_regularizer
from .pixelwise_loss import L1Loss, MSELoss
from .pixelwise_loss import L1Loss, MSELoss, gaussian_kld

__all__ = [
'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss'
'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
'gaussian_kld'
]
45 changes: 44 additions & 1 deletion mmgen/models/losses/pixelwise_loss.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmgen.models.builder import MODULES
from .utils import weighted_loss

_reduction_modes = ['none', 'mean', 'sum']
_reduction_modes = ['none', 'mean', 'sum', 'batchmean']


@weighted_loss
@@ -36,6 +38,47 @@ def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')


@weighted_loss
def gaussian_kld(mean_1, mean_2, logvar_1, logvar_2, base='e'):
r"""Calculate KLD (Kullback-Leibler divergence) of two gaussian
distribution.
To be noted that in this function, KLD is calcuated in base `e`.

.. math::
:nowrap:
\begin{eqnarray}
KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
&= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
\frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
\frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
&= \log{\frac{\sigma_2}{\sigma_1}} +
\frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
\end{eqnarray}

Args:
mean_1 (torch.Tensor): Mean of the first (or the target) distribution.
mean_2 (torch.Tensor): Mean of the second (or the predicted)
distribution.
logvar_1 (torch.Tensor): Log variance of the first (or the target)
distribution.
logvar_2 (torch.Tensor): Log variance of the second (or the predicted)
distribution.
base (str, optional): The log base of calculated KLD. Support ``'e'``
and ``'2'``. Defaults to ``'e'``.

Returns:
torch.Tensor: KLD between two given distribution.
"""
if base not in ['e', '2']:
raise ValueError('Only support \'e\' and \'2\' for the log base, but '
f'received \'{base}\'')
kld = 0.5 * (-1.0 + logvar_2 - logvar_1 + torch.exp(logvar_1 - logvar_2) +
((mean_1 - mean_2)**2) * torch.exp(-logvar_2))
if base == '2':
return kld / np.log(2)
return kld


@MODULES.register_module()
class MSELoss(nn.Module):
"""MSE loss.
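As a sanity check on the closed form, the element-wise result should match ``torch.distributions.kl_divergence`` for diagonal Gaussians. A sketch, assuming ``gaussian_kld`` is called with the same pattern used by the ``GaussianKLD`` metric above (``weight=1``, ``reduction='none'``):

```python
import torch
from torch.distributions import Normal, kl_divergence

from mmgen.models.losses import gaussian_kld

mean_1, mean_2 = torch.randn(5), torch.randn(5)
logvar_1, logvar_2 = torch.randn(5), torch.randn(5)

kld = gaussian_kld(
    mean_1, mean_2, logvar_1, logvar_2,
    weight=1, reduction='none', base='e')

# KL(N(mu_1, sigma_1) || N(mu_2, sigma_2)) with sigma = exp(0.5 * logvar).
ref = kl_divergence(
    Normal(mean_1, (0.5 * logvar_1).exp()),
    Normal(mean_2, (0.5 * logvar_2).exp()))
assert torch.allclose(kld, ref, atol=1e-6)
```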
5 changes: 4 additions & 1 deletion mmgen/models/losses/utils.py
@@ -9,11 +9,14 @@ def reduce_loss(loss, reduction):

Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
reduction (str): Options are "none", "mean", "sum" and "batchmean".

Return:
Tensor: Reduced loss tensor.
"""
if reduction == 'batchmean':
return loss.sum() / loss.shape[0]

reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
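``'batchmean'`` divides the summed loss by the leading batch dimension only; a quick standalone equivalence check:

```python
import torch

loss = torch.arange(24, dtype=torch.float32).reshape(4, 2, 3)
batchmean = loss.sum() / loss.shape[0]  # what reduce_loss does for 'batchmean'
# 'mean' divides by all 24 elements, 'batchmean' only by the 4 samples.
assert torch.isclose(batchmean, loss.mean() * loss[0].numel())
```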