From 736a71baac1a597408f237d2f9a026215b947e7d Mon Sep 17 00:00:00 2001 From: LeoXing Date: Wed, 15 Sep 2021 10:26:15 +0800 Subject: [PATCH] fix by comment --- mmgen/core/evaluation/evaluation.py | 6 +++--- mmgen/core/evaluation/metrics.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mmgen/core/evaluation/evaluation.py b/mmgen/core/evaluation/evaluation.py index 061003271..3adb650a4 100644 --- a/mmgen/core/evaluation/evaluation.py +++ b/mmgen/core/evaluation/evaluation.py @@ -280,7 +280,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger, 'channels in the first, ' 'not % d' % reals.shape[1]) if reals.shape[1] == 1: - reals = torch.cat([reals] * 3, dim=1) + reals = reals.repeat(1, 3, 1, 1) num_feed = metric.feed(reals, 'reals') if num_feed <= 0: break @@ -342,7 +342,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger, for metric in probabilistic_metrics: metric.prepare() pbar = mmcv.ProgressBar(len(data_loader)) - # here we assert probabilistic model have reconstruction mode + # here we assume probabilistic model have reconstruction mode kwargs['mode'] = 'reconstruction' for data in data_loader: # key for unconditional GAN @@ -362,7 +362,7 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger, 'channels in the first, ' 'not % d' % reals.shape[1]) if reals.shape[1] == 1: - reals = torch.cat([reals] * 3, dim=1) + reals = reals.repeat(1, 3, 1, 1) prob_dict = model(reals, return_loss=False, **kwargs) num_feed = metric.feed(prob_dict, 'reals') diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index 98da10f1f..5b0328f58 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -385,7 +385,7 @@ def feed(self, batch, mode): operation in 'feed_op' function. Args: - batch (Tensor|dict[Tensor]): Images or dict to be feeded into + batch (Tensor | dict): Images or dict to be feeded into metric object. If ``Tensor`` is passed, the order of ``Tensor`` should be "NCHW". If ``dict`` is passed, each term in the ``dict`` are ``Tensor`` with order "NCHW". @@ -1405,11 +1405,12 @@ def __init__(self, num_images, base='e', reduction='batchmean'): '2'], ('We only support log_base for \'e\' and \'2\'') self.reduction = reduction self.num_fake_feeded = self.num_images - self.kld = partial(gaussian_kld, weight=1, reduction='none', base=base) + self.cal_kld = partial( + gaussian_kld, weight=1, reduction='none', base=base) def prepare(self): """Prepare for evaluating models with this metric.""" - self.KLD = [] + self.kld = [] self.num_real_feeded = 0 @torch.no_grad() @@ -1434,8 +1435,8 @@ def feed_op(self, batch, mode): 'the same time. But keys in the given dict are ' f'{batch.keys()}. Some of the requirements are ' 'missing.') - kld = self.kld(batch['mean_target'], batch['mean_pred'], - batch['logvar_target'], batch['logvar_pred']) + kld = self.cal_kld(batch['mean_target'], batch['mean_pred'], + batch['logvar_target'], batch['logvar_pred']) if dist.is_initialized(): ws = dist.get_world_size() placeholder = [torch.zeros_like(kld) for _ in range(ws)] @@ -1445,7 +1446,7 @@ def feed_op(self, batch, mode): # in distributed training, we only collect features at rank-0. if (dist.is_initialized() and dist.get_rank() == 0) or not dist.is_initialized(): - self.KLD.append(kld.cpu()) + self.kld.append(kld.cpu()) @torch.no_grad() def summary(self): @@ -1454,7 +1455,7 @@ def summary(self): Returns: dict | list: Summarized results. """ - kld = torch.cat(self.KLD, dim=0) + kld = torch.cat(self.kld, dim=0) assert kld.shape[0] >= self.num_images kld_np = kld.numpy() if self.reduction == 'sum':