Skip to content

Commit

Permalink
[Fix] Fix bug of real samples number calculation in DDP evaluation (#150
Browse files Browse the repository at this point in the history
)

* fix bug of real samples number calculation in DDP evaluation

* remove comments

* avoid iter dataloader when do not need real samples

* fix bug

* fix bug in eval_trans
  • Loading branch information
LeoXing1996 committed Nov 23, 2021
1 parent 70c7dea commit 351dcaf
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 17 deletions.
46 changes: 33 additions & 13 deletions mmgen/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,22 @@ def after_train_iter(self, runner):

runner.model.eval()

batch_size = self.dataloader.batch_size
rank, ws = get_dist_info()
total_batch_size = batch_size * ws

# sample real images
max_num_images = max(metric.num_images for metric in self.metrics)
for metric in self.metrics:
if metric.num_real_feeded >= metric.num_real_need:
continue
mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
# feed in real images
max_real_num_images = max(metric.num_images - metric.num_real_feeded
for metric in self.metrics)
# define mmcv progress bar
if rank == 0 and max_real_num_images > 0:
mmcv.print_log(
f'Sample {max_real_num_images} real images for evaluation',
'mmgen')
pbar = mmcv.ProgressBar(max_real_num_images)

if max_real_num_images > 0:
for data in self.dataloader:
# key for unconditional GAN
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
Expand All @@ -242,16 +249,29 @@ def after_train_iter(self, runner):
'Only support `real_img` for unconditional '
'datasets and `img` for conditional '
'datasets.')
num_feed = metric.feed(reals, 'reals')

if reals.shape[1] not in [1, 3]:
raise RuntimeError('real images should have one or three '
'channels in the first, '
'not % d' % reals.shape[1])
if reals.shape[1] == 1:
reals = reals.repeat(1, 3, 1, 1)

num_feed = 0
for metric in self.metrics:
num_feed_ = metric.feed(reals, 'reals')
num_feed = max(num_feed_, num_feed)

if num_feed <= 0:
break

mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
'mmgen')
batch_size = self.dataloader.batch_size
if rank == 0:
pbar.update(num_feed)

rank, ws = get_dist_info()
total_batch_size = batch_size * ws
max_num_images = max(metric.num_images for metric in self.metrics)
if rank == 0:
mmcv.print_log(
f'Sample {max_num_images} fake images for evaluation', 'mmgen')

# define mmcv progress bar
if rank == 0:
Expand Down
11 changes: 9 additions & 2 deletions mmgen/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def feed(self, batch, mode):
mode (str): Mark the batch as real or fake images. Value can be
'reals' or 'fakes',
"""
_, ws = get_dist_info()
if mode == 'reals':
if self.num_real_feeded == self.num_real_need:
return 0
Expand All @@ -406,8 +407,11 @@ def feed(self, batch, mode):
end = min(batch_size,
self.num_real_need - self.num_real_feeded)
batch_to_feed = batch[:end, ...]

global_end = min(batch_size * ws,
self.num_real_need - self.num_real_feeded)
self.feed_op(batch_to_feed, mode)
self.num_real_feeded += end
self.num_real_feeded += global_end
return end

elif mode == 'fakes':
Expand All @@ -420,8 +424,11 @@ def feed(self, batch, mode):
batch_to_feed = {k: v[:end, ...] for k, v in batch.items()}
else:
batch_to_feed = batch[:end, ...]

global_end = min(batch_size * ws,
self.num_fake_need - self.num_fake_feeded)
self.feed_op(batch_to_feed, mode)
self.num_fake_feeded += end
self.num_fake_feeded += global_end
return end
else:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions tools/utils/translation_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def single_gpu_evaluation(model,
# select key to fetch fake images
target_domain = basic_table_info['target_domain']
source_domain = basic_table_info['source_domain']
# if no images, `num_exist` should be zero
# if no images, `num_needed` should be zero
data_loader_iter = iter(data_loader)
for begin in range(num_exist, num_needed, batch_size):
for begin in range(0, num_needed, batch_size):
end = min(begin + batch_size, max_num_images)
# for translation model, we feed them images from dataloader
data_batch = next(data_loader_iter)
Expand Down

0 comments on commit 351dcaf

Please sign in to comment.