[Feature] Support Evaluation Hook of Translation Models #127

Merged · 6 commits · Oct 15, 2021
Changes from 1 commit
fix lint
yangyifei authored and yangyifei committed Sep 28, 2021
commit 432850b80b164d5f6a7c89da58d02cfc66664114
19 changes: 12 additions & 7 deletions configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py
@@ -114,8 +114,16 @@
         pipeline=train_pipeline,
         domain_a=domain_a,
         domain_b=domain_b),
-    val=dict(dataroot=dataroot, domain_a=domain_a, domain_b=domain_b, pipeline=test_pipeline),
-    test=dict(dataroot=dataroot, domain_a=domain_a, domain_b=domain_b, pipeline=test_pipeline))
+    val=dict(
+        dataroot=dataroot,
+        domain_a=domain_a,
+        domain_b=domain_b,
+        pipeline=test_pipeline),
+    test=dict(
+        dataroot=dataroot,
+        domain_a=domain_a,
+        domain_b=domain_b,
+        pipeline=test_pipeline))

 optimizer = dict(
     generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
@@ -151,10 +159,7 @@
     target_domain='zebra',
     interval=200,
     metrics=[
-        dict(
-            type='FID',
-            num_images=140,
-            bgr2rgb=True),
+        dict(type='FID', num_images=140, bgr2rgb=True),
         dict(type='IS', num_images=140)
     ],
-    best_metric=['fid', 'is'])
+    best_metric=['fid', 'is'])
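
For orientation, the `evaluation` block this hunk edits reads roughly as follows after the commit. This is a minimal sketch: the `type='TranslationEvalHook'` key sits above the shown lines and outside the hunk, so it is an assumption here, while every other field is taken verbatim from the diff. Condensing the FID dict onto one line and wrapping the long `val`/`test` entries is what clears the line-length complaints this "fix lint" commit targets.

```python
# Sketch of the post-commit evaluation config (mmgen config style).
evaluation = dict(
    type='TranslationEvalHook',  # assumed: this key is outside the hunk above
    target_domain='zebra',       # evaluate the horse -> zebra direction
    interval=200,                # evaluate every 200 training iterations
    metrics=[
        dict(type='FID', num_images=140, bgr2rgb=True),
        dict(type='IS', num_images=140)
    ],
    best_metric=['fid', 'is'])   # track the best checkpoint for each metric
```
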
16 changes: 3 additions & 13 deletions mmgen/core/evaluation/__init__.py
@@ -7,18 +7,8 @@
                      sliced_wasserstein)

 __all__ = [
-    'MS_SSIM',
-    'SWD',
-    'ms_ssim',
-    'sliced_wasserstein',
-    'single_gpu_evaluation',
-    'single_gpu_online_evaluation',
-    'PR',
-    'IS',
-    'slerp',
-    'GenerativeEvalHook',
-    'make_metrics_table',
-    'make_vanilla_dataloader',
-    'GaussianKLD',
+    'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
+    'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
+    'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
     'TranslationEvalHook'
 ]
19 changes: 13 additions & 6 deletions mmgen/core/evaluation/eval_hooks.py
@@ -533,7 +533,8 @@ def after_train_iter(self, runner):
             return

         runner.model.eval()
-        source_domain = runner.model.module.get_other_domains(self.target_domain)[0]
+        source_domain = runner.model.module.get_other_domains(
+            self.target_domain)[0]
         # feed real images
         max_num_images = max(metric.num_images for metric in self.metrics)
         for metric in self.metrics:
@@ -547,7 +548,8 @@ def after_train_iter(self, runner):
                     reals = data[f'img_{self.target_domain}']
                 # key for conditional GAN
                 else:
-                    raise KeyError('Cannot found key for images in data_dict. ')
+                    raise KeyError(
+                        'Cannot found key for images in data_dict. ')
                 num_feed = metric.feed(reals, 'reals')
                 if num_feed <= 0:
                     break
@@ -568,19 +570,24 @@ def after_train_iter(self, runner):
                 # key for translation model
                 if f'img_{source_domain}' in data:
                     with torch.no_grad():
-                        output_dict = runner.model(data[f'img_{source_domain}'], test_mode=True, target_domain=self.target_domain, **self.sample_kwargs)
+                        output_dict = runner.model(
+                            data[f'img_{source_domain}'],
+                            test_mode=True,
+                            target_domain=self.target_domain,
+                            **self.sample_kwargs)
                         fakes = output_dict['target']
                 # key for conditional GAN
                 else:
-                    raise KeyError('Cannot found key for images in data_dict. ')
+                    raise KeyError(
+                        'Cannot found key for images in data_dict. ')
                 # sampling fake images and directly send them to metrics
                 for metric in self.metrics:
                     if metric.num_fake_feeded >= metric.num_fake_need:
                         continue
                     num_feed = metric.feed(fakes, 'fakes')
                     if num_feed <= 0:
                         break

                 if rank == 0:
                     pbar.update(total_batch_size)

@@ -597,7 +604,7 @@ def after_train_iter(self, runner):
             # record best metric and save the best ckpt
             if self.save_best_ckpt and name in self.best_metric:
                 self._save_best_ckpt(runner, val, name)
-
+
         runner.log_buffer.ready = True
         runner.model.train()
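
Taken together, these hunks only rewrap long lines; the hook's feed protocol is unchanged. Each metric consumes batches through `metric.feed(batch, mode)` until it has received its `num_images`, and a non-positive return value tells the hook to stop feeding that metric. Below is a minimal, self-contained sketch of that contract; `DummyMetric` is a hypothetical stand-in for mmgen's real FID/IS metric classes, not part of this PR.

```python
import torch


class DummyMetric:
    """Hypothetical stand-in for an mmgen metric such as FID or IS.

    Mimics the contract the hook relies on: feed() returns the number of
    images actually consumed, and a value <= 0 once the metric has seen
    enough images.
    """

    def __init__(self, num_images):
        self.num_images = num_images
        self.num_fake_feeded = 0  # counter name kept from the source diff

    def feed(self, batch, mode):
        remaining = self.num_images - self.num_fake_feeded
        if remaining <= 0:
            return 0  # saturated: the caller breaks out of its feed loop
        consumed = min(batch.shape[0], remaining)
        self.num_fake_feeded += consumed
        return consumed


metric = DummyMetric(num_images=140)  # matches num_images in the config above
while True:
    fakes = torch.zeros(16, 3, 256, 256)  # stands in for translated images
    if metric.feed(fakes, 'fakes') <= 0:
        break
print(metric.num_fake_feeded)  # 140, never more than requested
```
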