Skip to content

Commit

Permalink
[Feature] Support Evaluation Hook of Translation Models (#127)
Browse files Browse the repository at this point in the history
* support translation evaluation during training

* fix lint

* fix priority for eval hook

* fix pr and add evaluation to configs

* solve conflict

Co-authored-by: yangyifei <PJLAB\yangyifei@shai14001042l.pjlab.org>
  • Loading branch information
plyfager and yangyifei committed Oct 15, 2021
1 parent 6eb7045 commit de6cf63
Show file tree
Hide file tree
Showing 14 changed files with 343 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,24 @@
exp_name = 'cyclegan_summer2winter_id0'
work_dir = f'./work_dirs/experiments/{exp_name}'
# testA: 309, testB:238
num_images = 238
metrics = dict(
FID=dict(type='FID', num_images=238, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=238,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,24 @@
workflow = [('train', 1)]
exp_name = 'cyclegan_horse2zebra_id0'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 140
metrics = dict(
FID=dict(type='FID', num_images=140, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=140,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
18 changes: 16 additions & 2 deletions configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,24 @@
workflow = [('train', 1)]
exp_name = 'cyclegan_facades_id0'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 106
metrics = dict(
FID=dict(type='FID', num_images=106, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=106,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,24 @@
exp_name = 'cyclegan_summer2winter'
work_dir = f'./work_dirs/experiments/{exp_name}'
# testA: 309, testB:238
num_images = 238
metrics = dict(
FID=dict(type='FID', num_images=238, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=238,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
20 changes: 16 additions & 4 deletions configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]

test_pipeline = [
dict(
type='LoadImageFromFile',
Expand Down Expand Up @@ -109,7 +108,6 @@
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]

data = dict(
train=dict(
dataroot=dataroot,
Expand Down Expand Up @@ -151,10 +149,24 @@
exp_name = 'cyclegan_horse2zebra'
work_dir = f'./work_dirs/experiments/{exp_name}'
# testA 120, testB 140
num_images = 140
metrics = dict(
FID=dict(type='FID', num_images=140, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=140,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
18 changes: 16 additions & 2 deletions configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,24 @@
workflow = [('train', 1)]
exp_name = 'cyclegan_facades'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 106
metrics = dict(
FID=dict(type='FID', num_images=106, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=106,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
20 changes: 17 additions & 3 deletions configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
dataroot = 'data/paired/maps'
data = dict(
train=dict(dataroot=dataroot, pipeline=train_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'),
test=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'))

# optimizer
Expand Down Expand Up @@ -107,10 +107,24 @@
workflow = [('train', 1)]
exp_name = 'pix2pix_aerial2map'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 1098
metrics = dict(
FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=1098,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
20 changes: 17 additions & 3 deletions configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
dataroot = 'data/paired/maps'
data = dict(
train=dict(dataroot=dataroot, pipeline=train_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'),
test=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'))
# optimizer
optimizer = dict(
Expand All @@ -106,10 +106,24 @@
workflow = [('train', 1)]
exp_name = 'pix2pix_maps2aerial'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 1098
metrics = dict(
FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=1098,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
18 changes: 16 additions & 2 deletions configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,24 @@
workflow = [('train', 1)]
exp_name = 'pix2pix_facades'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 106
metrics = dict(
FID=dict(type='FID', num_images=106, image_shape=(3, 256, 256)),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=106,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
dataroot = 'data/paired/edges2shoes'
data = dict(
train=dict(dataroot=dataroot, pipeline=train_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline),
val=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'),
test=dict(dataroot=dataroot, pipeline=test_pipeline, testdir='val'))

# optimizer
Expand Down Expand Up @@ -106,11 +106,24 @@
workflow = [('train', 1)]
exp_name = 'pix2pix_edges2shoes_wo_jitter_flip'
work_dir = f'./work_dirs/experiments/{exp_name}'
num_images = 200
metrics = dict(
FID=dict(
type='FID', num_images=200, image_shape=(3, 256, 256), bgr2rgb=True),
FID=dict(type='FID', num_images=num_images, image_shape=(3, 256, 256)),
IS=dict(
type='IS',
num_images=200,
num_images=num_images,
image_shape=(3, 256, 256),
inception_args=dict(type='pytorch')))

evaluation = dict(
type='TranslationEvalHook',
target_domain=domain_b,
interval=10000,
metrics=[
dict(type='FID', num_images=num_images, bgr2rgb=True),
dict(
type='IS',
num_images=num_images,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
2 changes: 1 addition & 1 deletion mmgen/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def train_model(model,
val_dataloader = build_dataloader(
val_dataset, dist=distributed, **val_loader_cfg)
eval_cfg = deepcopy(cfg.get('evaluation'))
priority = eval_cfg.pop('priority', 'LOW')
eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
eval_hook = build_from_cfg(eval_cfg, HOOKS)
priority = eval_cfg.pop('priority', 'NORMAL')
runner.register_hook(eval_hook, priority=priority)

# user-defined hooks
Expand Down
19 changes: 5 additions & 14 deletions mmgen/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import GenerativeEvalHook
from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
single_gpu_evaluation, single_gpu_online_evaluation)
from .metric_utils import slerp
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
sliced_wasserstein)

__all__ = [
'MS_SSIM',
'SWD',
'ms_ssim',
'sliced_wasserstein',
'single_gpu_evaluation',
'single_gpu_online_evaluation',
'PR',
'IS',
'slerp',
'GenerativeEvalHook',
'make_metrics_table',
'make_vanilla_dataloader',
'GaussianKLD',
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
'TranslationEvalHook'
]
Loading

0 comments on commit de6cf63

Please sign in to comment.