Skip to content

Commit

Permalink
fix by comment
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Sep 15, 2021
1 parent ec1ba99 commit 736a71b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions mmgen/core/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
'channels in the first, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
reals = reals.repeat(1, 3, 1, 1)
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
Expand Down Expand Up @@ -342,7 +342,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
for metric in probabilistic_metrics:
metric.prepare()
pbar = mmcv.ProgressBar(len(data_loader))
# here we assert probabilistic model have reconstruction mode
# here we assume probabilistic model have reconstruction mode
kwargs['mode'] = 'reconstruction'
for data in data_loader:
# key for unconditional GAN
Expand All @@ -362,7 +362,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
'channels in the first, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
reals = reals.repeat(1, 3, 1, 1)

prob_dict = model(reals, return_loss=False, **kwargs)
num_feed = metric.feed(prob_dict, 'reals')
Expand Down
15 changes: 8 additions & 7 deletions mmgen/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def feed(self, batch, mode):
operation in 'feed_op' function.
Args:
batch (Tensor|dict[Tensor]): Images or dict to be feeded into
batch (Tensor | dict): Images or dict to be feeded into
metric object. If ``Tensor`` is passed, the order of ``Tensor``
should be "NCHW". If ``dict`` is passed, each term in the
``dict`` are ``Tensor`` with order "NCHW".
Expand Down Expand Up @@ -1405,11 +1405,12 @@ def __init__(self, num_images, base='e', reduction='batchmean'):
'2'], ('We only support log_base for \'e\' and \'2\'')
self.reduction = reduction
self.num_fake_feeded = self.num_images
self.kld = partial(gaussian_kld, weight=1, reduction='none', base=base)
self.cal_kld = partial(
gaussian_kld, weight=1, reduction='none', base=base)

def prepare(self):
"""Prepare for evaluating models with this metric."""
self.KLD = []
self.kld = []
self.num_real_feeded = 0

@torch.no_grad()
Expand All @@ -1434,8 +1435,8 @@ def feed_op(self, batch, mode):
'the same time. But keys in the given dict are '
f'{batch.keys()}. Some of the requirements are '
'missing.')
kld = self.kld(batch['mean_target'], batch['mean_pred'],
batch['logvar_target'], batch['logvar_pred'])
kld = self.cal_kld(batch['mean_target'], batch['mean_pred'],
batch['logvar_target'], batch['logvar_pred'])
if dist.is_initialized():
ws = dist.get_world_size()
placeholder = [torch.zeros_like(kld) for _ in range(ws)]
Expand All @@ -1445,7 +1446,7 @@ def feed_op(self, batch, mode):
# in distributed training, we only collect features at rank-0.
if (dist.is_initialized()
and dist.get_rank() == 0) or not dist.is_initialized():
self.KLD.append(kld.cpu())
self.kld.append(kld.cpu())

@torch.no_grad()
def summary(self):
Expand All @@ -1454,7 +1455,7 @@ def summary(self):
Returns:
dict | list: Summarized results.
"""
kld = torch.cat(self.KLD, dim=0)
kld = torch.cat(self.kld, dim=0)
assert kld.shape[0] >= self.num_images
kld_np = kld.numpy()
if self.reduction == 'sum':
Expand Down

0 comments on commit 736a71b

Please sign in to comment.