[Feature] Support KLD metric and support evaluation for probabilistic models #108

Merged (6 commits) on Sep 15, 2021
Changes from 1 commit
fix by comment
LeoXing1996 committed Sep 15, 2021
commit 736a71baac1a597408f237d2f9a026215b947e7d
6 changes: 3 additions & 3 deletions mmgen/core/evaluation/evaluation.py
@@ -280,7 +280,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                    'channels in the first, '
                                    'not % d' % reals.shape[1])
             if reals.shape[1] == 1:
-                reals = torch.cat([reals] * 3, dim=1)
+                reals = reals.repeat(1, 3, 1, 1)
             num_feed = metric.feed(reals, 'reals')
             if num_feed <= 0:
                 break
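For reference (not part of the PR): the two grayscale-to-RGB expansions swapped in this hunk produce identical tensors; `repeat` simply avoids building an intermediate list of copies. A minimal standalone sketch checking that equivalence:

# Sketch only, not PR code: verify that channel-repeat matches channel-concat.
import torch

reals = torch.rand(4, 1, 64, 64)         # NCHW batch of grayscale images
via_repeat = reals.repeat(1, 3, 1, 1)    # new code path in this commit
via_cat = torch.cat([reals] * 3, dim=1)  # old code path

assert via_repeat.shape == (4, 3, 64, 64)
assert torch.equal(via_repeat, via_cat)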
@@ -342,7 +342,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
     for metric in probabilistic_metrics:
         metric.prepare()
     pbar = mmcv.ProgressBar(len(data_loader))
-    # here we assert probabilistic model have reconstruction mode
+    # here we assume probabilistic models have a reconstruction mode
     kwargs['mode'] = 'reconstruction'
     for data in data_loader:
         # key for unconditional GAN
@@ -362,7 +362,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                    'channels in the first, '
                                    'not % d' % reals.shape[1])
         if reals.shape[1] == 1:
-            reals = torch.cat([reals] * 3, dim=1)
+            reals = reals.repeat(1, 3, 1, 1)
 
         prob_dict = model(reals, return_loss=False, **kwargs)
         num_feed = metric.feed(prob_dict, 'reals')
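The reconstruction-mode call above assumes the model returns a dict of Gaussian statistics that the KLD metric can consume (the key names come from the metrics.py changes below). A hypothetical toy model illustrating that contract, as an assumption rather than mmgen API:

# Hypothetical sketch of the assumed interface: a probabilistic model whose
# 'reconstruction' mode returns the Gaussian statistics fed to the KLD metric.
import torch
import torch.nn as nn


class ToyProbModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, reals, return_loss=False, mode='reconstruction'):
        assert mode == 'reconstruction'
        feat = self.encoder(reals)
        mean_pred, logvar_pred = feat[:, :4], feat[:, 4:]
        # A VAE-style model would compare against a standard normal prior.
        return dict(
            mean_pred=mean_pred,
            logvar_pred=logvar_pred,
            mean_target=torch.zeros_like(mean_pred),
            logvar_target=torch.zeros_like(logvar_pred))


prob_dict = ToyProbModel()(torch.rand(2, 3, 32, 32), return_loss=False)
assert set(prob_dict) == {'mean_pred', 'logvar_pred',
                          'mean_target', 'logvar_target'}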
15 changes: 8 additions & 7 deletions mmgen/core/evaluation/metrics.py
@@ -385,7 +385,7 @@ def feed(self, batch, mode):
         operation in 'feed_op' function.
 
         Args:
-            batch (Tensor|dict[Tensor]): Images or dict to be feeded into
+            batch (Tensor | dict): Images or dict to be fed into
                 metric object. If ``Tensor`` is passed, the order of ``Tensor``
                 should be "NCHW". If ``dict`` is passed, each term in the
                 ``dict`` are ``Tensor`` with order "NCHW".
@@ -1405,11 +1405,12 @@ def __init__(self, num_images, base='e', reduction='batchmean'):
             '2'], ('We only support log_base for \'e\' and \'2\'')
         self.reduction = reduction
         self.num_fake_feeded = self.num_images
-        self.kld = partial(gaussian_kld, weight=1, reduction='none', base=base)
+        self.cal_kld = partial(
+            gaussian_kld, weight=1, reduction='none', base=base)
 
     def prepare(self):
         """Prepare for evaluating models with this metric."""
-        self.KLD = []
+        self.kld = []
         self.num_real_feeded = 0
 
     @torch.no_grad()
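For readers unfamiliar with the helper being wrapped: the partial above fixes the options of a Gaussian KL divergence. A reference-only sketch of the closed-form element-wise divergence (the argument order mirrors the cal_kld call in feed_op; the divergence direction and mmgen's exact gaussian_kld signature are assumptions):

# Reference-only sketch, not mmgen's gaussian_kld: element-wise
# KL(N(mean_tar, exp(logvar_tar)) || N(mean_pred, exp(logvar_pred))).
import math

import torch


def gaussian_kld_ref(mean_tar, mean_pred, logvar_tar, logvar_pred, base='e'):
    kld = 0.5 * (logvar_pred - logvar_tar +
                 (logvar_tar.exp() + (mean_tar - mean_pred)**2) /
                 logvar_pred.exp() - 1)
    if base == '2':
        kld = kld / math.log(2)  # convert nats to bits
    return kld


zeros = torch.zeros(4, 8)
# Identical Gaussians have zero divergence.
assert torch.allclose(gaussian_kld_ref(zeros, zeros, zeros, zeros), zeros)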
@@ -1434,8 +1435,8 @@ def feed_op(self, batch, mode):
                 'the same time. But keys in the given dict are '
                 f'{batch.keys()}. Some of the requirements are '
                 'missing.')
-        kld = self.kld(batch['mean_target'], batch['mean_pred'],
-                       batch['logvar_target'], batch['logvar_pred'])
+        kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
+                           batch['logvar_target'], batch['logvar_pred'])
         if dist.is_initialized():
             ws = dist.get_world_size()
             placeholder = [torch.zeros_like(kld) for _ in range(ws)]
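The gather step that follows the placeholder allocation is collapsed in this diff view; a typical pattern for collecting per-rank KLD tensors, offered only as an assumption about the hidden lines:

# Assumed pattern only; the actual gathered lines are folded away above.
import torch
import torch.distributed as dist

kld = torch.rand(8)  # stand-in for the per-sample KLD computed in feed_op
if dist.is_initialized():
    ws = dist.get_world_size()
    placeholder = [torch.zeros_like(kld) for _ in range(ws)]
    dist.all_gather(placeholder, kld)    # every rank receives every rank's chunk
    kld = torch.cat(placeholder, dim=0)  # merge into a single batch dimension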
@@ -1445,7 +1446,7 @@
         # in distributed training, we only collect features at rank-0.
         if (dist.is_initialized()
                 and dist.get_rank() == 0) or not dist.is_initialized():
-            self.KLD.append(kld.cpu())
+            self.kld.append(kld.cpu())
 
     @torch.no_grad()
     def summary(self):
@@ -1454,7 +1455,7 @@
         Returns:
             dict | list: Summarized results.
         """
-        kld = torch.cat(self.KLD, dim=0)
+        kld = torch.cat(self.kld, dim=0)
         assert kld.shape[0] >= self.num_images
         kld_np = kld.numpy()
         if self.reduction == 'sum':
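The reduction branch is truncated by the diff; a rough sketch of what the two modes could compute over the collected per-sample values (an assumption, only the 'sum' and 'batchmean' names come from __init__ above):

# Rough sketch, not mmgen's summary(): possible 'sum' vs 'batchmean'
# reductions over the first num_images per-sample KLD values.
import numpy as np


def reduce_kld(kld_np: np.ndarray, num_images: int, reduction: str) -> float:
    kld_np = kld_np[:num_images]
    if reduction == 'sum':
        return float(kld_np.sum())
    # 'batchmean': total divergence divided by the number of samples.
    return float(kld_np.sum() / num_images)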