Skip to content

Commit

Permalink
Merge ce6e4eb into 7e3aacc
Browse files Browse the repository at this point in the history
  • Loading branch information
plyfager authored Sep 29, 2021
2 parents 7e3aacc + ce6e4eb commit ac7211c
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 17 deletions.
52 changes: 50 additions & 2 deletions configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,51 @@
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]
test_pipeline = [
dict(
type='LoadImageFromFile',
io_backend='disk',
key=f'img_{domain_a}',
flag='color'),
dict(
type='LoadImageFromFile',
io_backend='disk',
key=f'img_{domain_b}',
flag='color'),
dict(
type='Resize',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
scale=(256, 256),
interpolation='bicubic'),
dict(type='RescaleToZeroOne', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Normalize',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
to_rgb=False,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]),
dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Collect',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]
data = dict(
train=dict(
dataroot=dataroot,
pipeline=train_pipeline,
domain_a=domain_a,
domain_b=domain_b),
val=dict(dataroot=dataroot, domain_a=domain_a, domain_b=domain_b),
test=dict(dataroot=dataroot, domain_a=domain_a, domain_b=domain_b))
val=dict(
dataroot=dataroot,
domain_a=domain_a,
domain_b=domain_b,
pipeline=test_pipeline),
test=dict(
dataroot=dataroot,
domain_a=domain_a,
domain_b=domain_b,
pipeline=test_pipeline))

optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
Expand Down Expand Up @@ -115,3 +152,14 @@
metrics = dict(
FID=dict(type='FID', num_images=140, image_shape=(3, 256, 256)),
IS=dict(type='IS', num_images=140, image_shape=(3, 256, 256)))

# inception_pkl = None
evaluation = dict(
type='TranslationEvalHook',
target_domain='zebra',
interval=200,
metrics=[
dict(type='FID', num_images=140, bgr2rgb=True),
dict(type='IS', num_images=140)
],
best_metric=['fid', 'is'])
2 changes: 1 addition & 1 deletion mmgen/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def train_model(model,
val_dataloader = build_dataloader(
val_dataset, dist=distributed, **val_loader_cfg)
eval_cfg = deepcopy(cfg.get('evaluation'))
priority = eval_cfg.pop('priority', 'LOW')
eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
eval_hook = build_from_cfg(eval_cfg, HOOKS)
priority = eval_cfg.pop('priority', 'NORMAL')
runner.register_hook(eval_hook, priority=priority)

# user-defined hooks
Expand Down
19 changes: 5 additions & 14 deletions mmgen/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import GenerativeEvalHook
from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
single_gpu_evaluation, single_gpu_online_evaluation)
from .metric_utils import slerp
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
sliced_wasserstein)

__all__ = [
'MS_SSIM',
'SWD',
'ms_ssim',
'sliced_wasserstein',
'single_gpu_evaluation',
'single_gpu_online_evaluation',
'PR',
'IS',
'slerp',
'GenerativeEvalHook',
'make_metrics_table',
'make_vanilla_dataloader',
'GaussianKLD',
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'single_gpu_evaluation',
'single_gpu_online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
'TranslationEvalHook'
]
Loading

0 comments on commit ac7211c

Please sign in to comment.