Merge 736a71b into bfd00c3
LeoXing1996 committed Sep 15, 2021
2 parents bfd00c3 + 736a71b commit 1146a8f
Showing 7 changed files with 303 additions and 19 deletions.
19 changes: 15 additions & 4 deletions mmgen/core/evaluation/__init__.py
@@ -3,10 +3,21 @@
 from .evaluation import (make_metrics_table, make_vanilla_dataloader,
                          single_gpu_evaluation, single_gpu_online_evaluation)
from .metric_utils import slerp
-from .metrics import IS, MS_SSIM, PR, SWD, ms_ssim, sliced_wasserstein
+from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
+                      sliced_wasserstein)

__all__ = [
-    'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
-    'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
-    'make_metrics_table', 'make_vanilla_dataloader'
+    'MS_SSIM',
+    'SWD',
+    'ms_ssim',
+    'sliced_wasserstein',
+    'single_gpu_evaluation',
+    'single_gpu_online_evaluation',
+    'PR',
+    'IS',
+    'slerp',
+    'GenerativeEvalHook',
+    'make_metrics_table',
+    'make_vanilla_dataloader',
+    'GaussianKLD',
]
43 changes: 41 additions & 2 deletions mmgen/core/evaluation/evaluation.py
@@ -229,17 +229,26 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
of the metric table include training configuration and ckpt.
kwargs (dict): Other arguments.
"""
-    # separate metrics into special metrics and vanilla metrics.
+    # Separate metrics into special metrics, probabilistic metrics and
+    # vanilla metrics.
     # For vanilla metrics, images are generated in a random way, and are
     # shared by these metrics. For special metrics like 'PPL', images are
     # generated in a metric-specific way and not shared between different
     # metrics.
+    # Probabilistic metrics like 'GaussianKLD' do not receive images but a
+    # dict with the corresponding probabilistic parameters. To make the
+    # model return probabilistic parameters, ``kwargs['mode']`` is set to
+    # 'reconstruction' below.

     special_metrics = []
+    probabilistic_metrics = []
     vanilla_metrics = []
     special_metric_name = ['PPL']
+    probabilistic_metric_name = ['GaussianKLD']
     for metric in metrics:
         if metric.name in special_metric_name:
             special_metrics.append(metric)
+        elif metric.name in probabilistic_metric_name:
+            probabilistic_metrics.append(metric)
         else:
             vanilla_metrics.append(metric)

@@ -271,7 +280,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                'channels in the first dimension, '
                                'not %d' % reals.shape[1])
         if reals.shape[1] == 1:
-            reals = torch.cat([reals] * 3, dim=1)
+            reals = reals.repeat(1, 3, 1, 1)
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
@@ -329,6 +338,36 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
# finish the pbar stdout
sys.stdout.write('\n')

+    # feed probabilistic metrics
+    for metric in probabilistic_metrics:
+        metric.prepare()
+        pbar = mmcv.ProgressBar(len(data_loader))
+        # here we assume the probabilistic model has a 'reconstruction' mode
+        kwargs['mode'] = 'reconstruction'
+        for data in data_loader:
+            # key for unconditional GAN
+            if 'real_img' in data:
+                reals = data['real_img']
+            # key for conditional GAN
+            elif 'img' in data:
+                reals = data['img']
+            else:
+                raise KeyError('Cannot find key for images in data_dict. '
+                               'Only support `real_img` for unconditional '
+                               'datasets and `img` for conditional '
+                               'datasets.')
+
+            if reals.shape[1] not in [1, 3]:
+                raise RuntimeError('real images should have one or three '
+                                   'channels in the first dimension, '
+                                   'not %d' % reals.shape[1])
+            if reals.shape[1] == 1:
+                reals = reals.repeat(1, 3, 1, 1)
+
+            prob_dict = model(reals, return_loss=False, **kwargs)
+            num_feed = metric.feed(prob_dict, 'reals')
+            pbar.update(num_feed)

for metric in metrics:
metric.summary()

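
To see how this new branch is driven end to end, here is a minimal runnable sketch (not part of this commit). `fake_model` is a hypothetical stand-in for a VAE-style network; we only assume that, as in the loop above, calling it with `mode='reconstruction'` returns the dict of probabilistic parameters that GaussianKLD expects.

import torch

from mmgen.core.evaluation.metrics import GaussianKLD


def fake_model(reals, return_loss=False, mode='reconstruction'):
    # Hypothetical stand-in: a real model would infer these parameters
    # from `reals`. Shapes (N, 8) are purely illustrative.
    n = reals.shape[0]
    return dict(
        mean_pred=torch.zeros(n, 8), mean_target=torch.zeros(n, 8),
        logvar_pred=torch.zeros(n, 8), logvar_target=torch.zeros(n, 8))


metric = GaussianKLD(num_images=8)
metric.prepare()
for _ in range(2):  # stands in for `for data in data_loader:`
    reals = torch.rand(4, 3, 16, 16) * 2 - 1  # NCHW in [-1, 1]
    prob_dict = fake_model(reals, return_loss=False, mode='reconstruction')
    metric.feed(prob_dict, 'reals')  # GaussianKLD only consumes 'reals'
print(metric.summary())  # 0.0 here, since pred and target coincide
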
139 changes: 131 additions & 8 deletions mmgen/core/evaluation/metrics.py
@@ -3,6 +3,7 @@
import pickle
from abc import ABC, abstractmethod
from copy import deepcopy
+from functools import partial

import mmcv
import numpy as np
@@ -18,6 +19,7 @@
from mmgen.models.architectures import InceptionV3
from mmgen.models.architectures.common import get_module_device
from mmgen.models.architectures.lpips import PerceptualLoss
+from mmgen.models.losses import gaussian_kld
from mmgen.utils import MMGEN_CACHE_DIR
from mmgen.utils.io_utils import download_from_url
from ..registry import METRICS
@@ -383,18 +385,28 @@ def feed(self, batch, mode):
operation in 'feed_op' function.
Args:
-            batch (Tensor): Images feeded into metric object with order "NCHW"
-                and range [-1, 1].
+            batch (Tensor | dict): Images or dict to be fed into the
+                metric object. If a ``Tensor`` is passed, its order
+                should be "NCHW". If a ``dict`` is passed, each value in
+                the ``dict`` should be a ``Tensor`` with order "NCHW".
             mode (str): Mark the batch as real or fake images. Value can be
                 'reals' or 'fakes'.
"""
if mode == 'reals':
if self.num_real_feeded == self.num_real_need:
return 0

-            batch_size = batch.shape[0]
-            end = min(batch_size, self.num_real_need - self.num_real_feeded)
-            self.feed_op(batch[:end, :, :, :], mode)
+            if isinstance(batch, dict):
+                batch_size = [v for v in batch.values()][0].shape[0]
+                end = min(batch_size,
+                          self.num_real_need - self.num_real_feeded)
+                batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
+            else:
+                batch_size = batch.shape[0]
+                end = min(batch_size,
+                          self.num_real_need - self.num_real_feeded)
+                batch_to_feed = batch[:end, ...]
+            self.feed_op(batch_to_feed, mode)
self.num_real_feeded += end
return end

@@ -404,13 +416,17 @@ def feed(self, batch, mode):

             batch_size = batch.shape[0]
             end = min(batch_size, self.num_fake_need - self.num_fake_feeded)
-            self.feed_op(batch[:end, :, :, :], mode)
+            if isinstance(batch, dict):
+                batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
+            else:
+                batch_to_feed = batch[:end, ...]
+            self.feed_op(batch_to_feed, mode)
self.num_fake_feeded += end
return end
         else:
             raise ValueError(
-                f"The expected mode should be set to 'reals' or 'fakes,\
-                but got '{mode}'")
+                'The expected mode should be set to \'reals\' or \'fakes\', '
+                f'but got \'{mode}\'')

def check(self):
"""Check the numbers of image."""
@@ -1343,3 +1359,110 @@ def __next__(self):

self.idx += 1
return image


+class GaussianKLD(Metric):
+    r"""Gaussian KLD (Kullback-Leibler divergence) metric. We calculate the
+    KLD between two Gaussian distributions via their `mean` and
+    `log_variance`.
+
+    The passed batch should be a dict instance and contain ``mean_pred``,
+    ``mean_target``, ``logvar_pred`` and ``logvar_target``. When calling the
+    ``feed`` operation, only the ``reals`` mode is needed.
+
+    The calculation of KLD can be formulated as:
+
+    .. math::
+        :nowrap:
+
+        \begin{eqnarray}
+        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
+            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
+            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
+            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
+            &= \log{\frac{\sigma_2}{\sigma_1}} +
+            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
+        \end{eqnarray}
+
+    where `p` and `q` denote the target and the predicted distribution
+    respectively.
+
+    Args:
+        num_images (int): The number of samples to be tested.
+        base (str, optional): The log base of the calculated KLD. Support
+            ``'e'`` and ``'2'``. Defaults to ``'e'``.
+        reduction (str, optional): Specifies the reduction to apply to the
+            output. Support ``'batchmean'``, ``'sum'`` and ``'mean'``. If
+            ``reduction == 'batchmean'``, the sum of the output will be
+            divided by the batch size. If ``reduction == 'sum'``, the output
+            will be summed. If ``reduction == 'mean'``, the output will be
+            divided by the number of elements in the output. Defaults to
+            ``'batchmean'``.
+    """
+    name = 'GaussianKLD'
+
+    def __init__(self, num_images, base='e', reduction='batchmean'):
+        super().__init__(num_images, image_shape=None)
+        assert reduction in [
+            'sum', 'batchmean', 'mean'
+        ], ('We only support reduction modes \'batchmean\', \'sum\' '
+            'and \'mean\'')
+        assert base in ['e', '2'], (
+            'We only support log base \'e\' and \'2\'')
+        self.reduction = reduction
+        self.num_fake_feeded = self.num_images
+        self.cal_kld = partial(
+            gaussian_kld, weight=1, reduction='none', base=base)
+
+    def prepare(self):
+        """Prepare for evaluating models with this metric."""
+        self.kld = []
+        self.num_real_feeded = 0
+
+    @torch.no_grad()
+    def feed_op(self, batch, mode):
+        """Feed data to the metric.
+
+        Args:
+            batch (dict): Dict of probabilistic parameters.
+            mode (str): The mode of the current data batch, 'reals' or
+                'fakes'.
+        """
+        if mode == 'fakes':
+            return
+        assert isinstance(batch, dict), ('To calculate GaussianKLD, a '
+                                         'dict containing probabilistic '
+                                         'parameters is required.')
+        # check required keys
+        require_keys = [
+            'mean_pred', 'mean_target', 'logvar_pred', 'logvar_target'
+        ]
+        if any([k not in batch for k in require_keys]):
+            raise KeyError(f'The input dict must contain {require_keys} at '
+                           'the same time, but the keys in the given dict '
+                           f'are {batch.keys()}. Some of the required keys '
+                           'are missing.')
+        kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
+                           batch['logvar_target'], batch['logvar_pred'])
+        if dist.is_initialized():
+            ws = dist.get_world_size()
+            placeholder = [torch.zeros_like(kld) for _ in range(ws)]
+            dist.all_gather(placeholder, kld)
+            kld = torch.cat(placeholder, dim=0)
+
+        # in distributed training, we only collect features at rank-0.
+        if (dist.is_initialized()
+                and dist.get_rank() == 0) or not dist.is_initialized():
+            self.kld.append(kld.cpu())
+
+    @torch.no_grad()
+    def summary(self):
+        """Summarize the results.
+
+        Returns:
+            float: Summarized result.
+        """
+        kld = torch.cat(self.kld, dim=0)
+        assert kld.shape[0] >= self.num_images
+        kld_np = kld.numpy()
+        if self.reduction == 'sum':
+            kld_result = np.sum(kld_np)
+        elif self.reduction == 'mean':
+            kld_result = np.mean(kld_np)
+        else:
+            kld_result = np.sum(kld_np) / kld_np.shape[0]
+        self._result_str = f'{kld_result:.4f}'
+        return kld_result
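
A quick sanity check of the metric in isolation (a sketch assuming only the class above and torch, with no distributed setup): with a target of N(1, 1) against a prediction of N(0, 1), each element contributes 0.5 to the KLD, so 8 dimensions per sample give 4.0 under 'batchmean'.

import torch

# Dict batches are sliced key-by-key in `feed`, as implemented above.
metric = GaussianKLD(num_images=4, reduction='batchmean')
metric.prepare()
metric.feed(
    dict(
        mean_target=torch.ones(4, 8),
        mean_pred=torch.zeros(4, 8),
        logvar_target=torch.zeros(4, 8),
        logvar_pred=torch.zeros(4, 8)), 'reals')
print(metric.summary())  # 4.0 == 8 dims * 0.5 per element
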
5 changes: 3 additions & 2 deletions mmgen/models/losses/__init__.py
@@ -5,10 +5,11 @@
r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import GeneratorPathRegularizer, gen_path_regularizer
-from .pixelwise_loss import L1Loss, MSELoss
+from .pixelwise_loss import L1Loss, MSELoss, gaussian_kld

__all__ = [
'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
-    'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss'
+    'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
+    'gaussian_kld'
]
45 changes: 44 additions & 1 deletion mmgen/models/losses/pixelwise_loss.py
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmgen.models.builder import MODULES
from .utils import weighted_loss

-_reduction_modes = ['none', 'mean', 'sum']
+_reduction_modes = ['none', 'mean', 'sum', 'batchmean']


@weighted_loss
@@ -36,6 +38,47 @@ def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')


+@weighted_loss
+def gaussian_kld(mean_1, mean_2, logvar_1, logvar_2, base='e'):
+    r"""Calculate the KLD (Kullback-Leibler divergence) of two Gaussian
+    distributions.
+
+    By default, the KLD is calculated in base `e`.
+
+    .. math::
+        :nowrap:
+
+        \begin{eqnarray}
+        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
+            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
+            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
+            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
+            &= \log{\frac{\sigma_2}{\sigma_1}} +
+            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
+        \end{eqnarray}
+
+    Args:
+        mean_1 (torch.Tensor): Mean of the first (or the target)
+            distribution.
+        mean_2 (torch.Tensor): Mean of the second (or the predicted)
+            distribution.
+        logvar_1 (torch.Tensor): Log variance of the first (or the target)
+            distribution.
+        logvar_2 (torch.Tensor): Log variance of the second (or the
+            predicted) distribution.
+        base (str, optional): The log base of the calculated KLD. Support
+            ``'e'`` and ``'2'``. Defaults to ``'e'``.
+
+    Returns:
+        torch.Tensor: KLD between the two given distributions.
+    """
+    if base not in ['e', '2']:
+        raise ValueError('Only support \'e\' and \'2\' for the log base, '
+                         f'but received {base}')
+    kld = 0.5 * (-1.0 + logvar_2 - logvar_1 + torch.exp(logvar_1 - logvar_2) +
+                 ((mean_1 - mean_2)**2) * torch.exp(-logvar_2))
+    if base == '2':
+        return kld / np.log(2)
+    return kld


@MODULES.register_module()
class MSELoss(nn.Module):
"""MSE loss.
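
The closed form is easy to verify numerically. The following self-contained sketch mirrors the `kld` line above in plain torch (bypassing the `@weighted_loss` wrapper); dividing by ln 2 converts the result to base 2, exactly as the code does.

import math

import torch


def kld_e(mean_1, mean_2, logvar_1, logvar_2):
    # Same closed form as `gaussian_kld` above, in base e.
    return 0.5 * (-1.0 + logvar_2 - logvar_1 +
                  torch.exp(logvar_1 - logvar_2) +
                  (mean_1 - mean_2)**2 * torch.exp(-logvar_2))


zero = torch.zeros(1)
one = torch.ones(1)
print(kld_e(zero, zero, zero, zero))  # tensor([0.]) -- identical Gaussians
print(kld_e(one, zero, zero, zero))   # tensor([0.5000]) == 0.5*(mu1 - mu2)^2
print(kld_e(one, zero, zero, zero) / math.log(2))  # the same KLD in bits
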
5 changes: 4 additions & 1 deletion mmgen/models/losses/utils.py
@@ -9,11 +9,14 @@ def reduce_loss(loss, reduction):
Args:
loss (Tensor): Elementwise loss tensor.
-        reduction (str): Options are "none", "mean" and "sum".
+        reduction (str): Options are "none", "mean", "sum" and "batchmean".
Return:
Tensor: Reduced loss tensor.
"""
+    if reduction == 'batchmean':
+        return loss.sum() / loss.shape[0]

reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
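
For illustration, a small sketch (assuming, as this implementation does, that the first tensor dimension is the batch) contrasting the new 'batchmean' reduction with 'mean' and 'sum':

import torch

loss = torch.arange(6.).reshape(2, 3)  # batch size 2, 3 elements per sample
print(loss.sum() / loss.shape[0])  # tensor(7.5000) -- 'batchmean'
print(loss.mean())                 # tensor(2.5000) -- 'mean'
print(loss.sum())                  # tensor(15.)    -- 'sum'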