Skip to content

Commit

Permalink
Merge pull request #23 from LeoXing1996/fix_pbar_and_graryscale_saving
Browse files Browse the repository at this point in the history
[Enhancement] Add pbar to offline evaluation + Fix bug when save or evaluate grayscale images
  • Loading branch information
nbei committed May 12, 2021
2 parents a5e9a94 + d8168bc commit 54fa45d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
36 changes: 35 additions & 1 deletion mmgen/core/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,20 @@ def single_gpu_evaluation(model,
sample_model=basic_table_info['sample_model'],
**kwargs)
pbar.update(end - begin)

# save as three-channel
if fakes.size(1) == 3:
fakes = fakes[:, [2, 1, 0], ...]
elif fakes.size(1) == 1:
fakes = torch.cat([fakes] * 3, dim=1)
else:
raise RuntimeError('Generated images must have one or three '
'channels in the first dimension, '
'not %d' % fakes.size(1))

for i in range(end - begin):
images = fakes[i:i + 1]
images = ((images + 1) / 2)
images = images[:, [2, 1, 0], ...]
images = images.clamp_(0, 1)
image_name = str(begin + i) + '.png'
save_image(images, os.path.join(samples_path, image_name))
Expand All @@ -153,19 +163,29 @@ def single_gpu_evaluation(model,
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
# prepare for pbar
total_need = metric.num_real_need + metric.num_fake_need
pbar = mmcv.ProgressBar(total_need)
# feed in real images
for data in data_loader:
reals = data['real_img']
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
num_left = metric.feed(reals, 'reals')
pbar.update(reals.shape[0])
if num_left <= 0:
break
# feed in fake images
for data in fake_dataloader:
fakes = data['real_img']
if fakes.shape[1] == 1:
fakes = torch.cat([fakes] * 3, dim=1)
num_left = metric.feed(fakes, 'fakes')
pbar.update(fakes.shape[0])
if num_left <= 0:
break
metric.summary()
sys.stdout.write('\n')
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'], metrics)
Expand Down Expand Up @@ -220,6 +240,13 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
# feed in real images
for data in data_loader:
reals = data['real_img']

if reals.shape[1] not in [1, 3]:
raise RuntimeError('real images should have one or three '
'channels in the first, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = torch.cat([reals] * 3, dim=1)
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
Expand All @@ -244,6 +271,13 @@ def single_gpu_online_evaluation(model, data_loader, metrics, logger,
return_loss=False,
sample_model=basic_table_info['sample_model'],
**kwargs)

if fakes.shape[1] not in [1, 3]:
raise RuntimeError('fakes images should have one or three '
'channels in the first, '
'not % d' % fakes.shape[1])
if fakes.shape[1] == 1:
fakes = torch.cat([fakes] * 3, dim=1)
pbar.update(end - begin)
fakes = fakes[:end - begin]

Expand Down
4 changes: 3 additions & 1 deletion mmgen/core/hooks/visualize_training_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ def after_train_iter(self, runner):
filename = self.filename_tmpl.format(runner.iter + 1)
if self.rerange:
imgs = ((imgs + 1) / 2)
if self.bgr2rgb:
if self.bgr2rgb and imgs.size(1) == 3:
imgs = imgs[:, [2, 1, 0], ...]
if imgs.size(1) == 1:
imgs = torch.cat([imgs, imgs, imgs], dim=1)
imgs = imgs.clamp_(0, 1)

mmcv.mkdir_or_exist(osp.join(runner.work_dir, self.output_dir))
Expand Down

0 comments on commit 54fa45d

Please sign in to comment.