Commit

Merge branch 'master' of github.com:open-mmlab/mmgeneration into biggan-benchmark
yangyifei authored and yangyifei committed Jul 30, 2021
2 parents 6739fd7 + 0c4052e commit f458fa7
Showing 15 changed files with 1,766 additions and 115 deletions.
8 changes: 6 additions & 2 deletions configs/_base_/datasets/cifar10_noaug.py
@@ -27,8 +27,12 @@
samples_per_gpu=None,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=500,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
pipeline=train_pipeline)),
val=dict(
type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
test=dict(
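
The key change above wraps the CIFAR-10 train set in a `RepeatDataset` with `times=500`, so an iteration-based runner does not restart the dataloader after every pass over the 50k images. A minimal sketch of the wrapper's assumed semantics (the authoritative implementation lives in the mm-series dataset wrappers):

# Minimal sketch of `RepeatDataset` semantics (assumed, for illustration).
class RepeatDataset:
    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times

    def __len__(self):
        # e.g. 50,000 CIFAR-10 images * 500 repeats per "epoch"
        return self.times * len(self.dataset)

    def __getitem__(self, idx):
        # indices wrap around the underlying dataset
        return self.dataset[idx % len(self.dataset)]
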
30 changes: 30 additions & 0 deletions configs/_base_/models/biggan_32x32.py
@@ -0,0 +1,30 @@
model = dict(
type='BasiccGAN',
num_classes=10,
generator=dict(
type='BigGANGenerator',
output_scale=32,
noise_size=128,
num_classes=10,
base_channels=64,
with_shared_embedding=False,
sn_eps=1e-8,
init_type='N02',
split_noise=False,
auto_sync_bn=False),
discriminator=dict(
type='BigGANDiscriminator',
input_scale=32,
num_classes=10,
base_channels=64,
sn_eps=1e-8,
init_type='N02',
with_spectral_norm=True),
gan_loss=dict(type='GANLoss', gan_type='hinge'))

train_cfg = dict(
disc_steps=4, gen_steps=1, batch_accumulation_steps=1, use_ema=True)
test_cfg = None
optimizer = dict(
generator=dict(type='Adam', lr=0.0002, betas=(0.0, 0.999)),
discriminator=dict(type='Adam', lr=0.0002, betas=(0.0, 0.999)))
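
For reference, `train_cfg` above schedules four discriminator updates per generator update (`disc_steps=4, gen_steps=1`) with no gradient accumulation. A rough sketch of that alternation, with hypothetical `update_discriminator`/`update_generator` helpers standing in for the real loss steps (this is not mmgen's actual training loop):

# Rough sketch of the disc_steps/gen_steps schedule; helper names are
# hypothetical.
disc_steps, gen_steps = 4, 1
for it, batch in enumerate(data_loader):
    update_discriminator(batch)          # runs every iteration
    if (it + 1) % disc_steps == 0:
        for _ in range(gen_steps):
            update_generator(batch)      # runs every 4th iteration
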
60 changes: 60 additions & 0 deletions configs/biggan/biggan_cifar10_32x32_b25x2_500k.py
@@ -0,0 +1,60 @@
_base_ = [
'../_base_/models/biggan_32x32.py', '../_base_/datasets/cifar10_noaug.py',
'../_base_/default_runtime.py'
]

# define dataset
# you must set `samples_per_gpu`
data = dict(samples_per_gpu=25, workers_per_gpu=8)

# adjust running config
lr_config = None
checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=20)
custom_hooks = [
dict(
type='VisualizeUnconditionalSamples',
output_dir='training_samples',
interval=5000),
dict(
type='ExponentialMovingAverageHook',
module_keys=('generator_ema', ),
interval=4,
start_iter=4000,
interp_cfg=dict(momentum=0.9999),
priority='VERY_HIGH')
]

total_iters = 500000

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
type='DynamicIterBasedRunner',
is_dynamic_ddp=False, # Note that this flag should be False.
pass_training_status=True)

# NOTE: set your inception_pkl path here
inception_pkl = None
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=[
dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
bgr2rgb=True),
dict(type='IS', num_images=50000)
],
sample_kwargs=dict(sample_model='ema'),
best_metric=['fid', 'is'])

metrics = dict(
fid50k=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
bgr2rgb=True),
is50k=dict(type='IS', num_images=50000))
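
One detail in this config is easy to miss: the EMA hook fires every 4 iterations, which matches `disc_steps=4` in the base model config, so the `generator_ema` weights are averaged once per generator update. The rule below is the standard exponential moving average with the configured `momentum=0.9999`; it is an assumed sketch, not the hook's literal code (see `ExponentialMovingAverageHook` for the authoritative version):

import torch

# Assumed EMA update for `generator_ema` (standard convention).
@torch.no_grad()
def ema_update(ema_model, model, momentum=0.9999):
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(momentum).add_(p, alpha=1 - momentum)
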
78 changes: 74 additions & 4 deletions mmgen/core/evaluation/eval_hooks.py
@@ -3,6 +3,7 @@
import os.path as osp
import sys
import warnings
from bisect import bisect_right

import mmcv
import torch
@@ -15,7 +16,7 @@
class GenerativeEvalHook(Hook):
"""Evaluation Hook for Generative Models.
This evaluation hook can be used to evaluate unconditional and conditional
models. Note that only the ``FID`` and ``IS`` metrics are supported for
distributed training now. In the future, we will support more metrics for
the evaluation during the training procedure.
@@ -25,6 +26,12 @@ class GenerativeEvalHook(Hook):
What you need to do is to add these lines at the end of your config file.
Then, you can use this evaluation hook in the training procedure.
Note that this evaluation hook supports evaluation with dynamic intervals,
since FID and other metrics may fluctuate frequently near the end of the
training process.
# TODO: fix the online doc
#. Only use FID for evaluation
.. code-block:: python
@@ -58,9 +65,36 @@ class GenerativeEvalHook(Hook):
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
#. Use dynamic evaluation intervals
.. code-block:: python
:linenos:
# interval = 10000 if iter < 500000,
# interval = 4000 if 500000 <= iter < 750000,
# interval = 2000 if iter >= 750000
evaluation = dict(
type='GenerativeEvalHook',
interval=dict(milestones=[500000, 750000],
interval=[10000, 4000, 2000]),
metrics=[dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
dict(type='IS',
num_images=50000)],
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
Args:
dataloader (DataLoader): A PyTorch dataloader.
interval (int | dict): Evaluation interval. If an int is passed,
``eval_hook`` runs at the given fixed interval. If a dict is passed,
it must contain ``milestones`` and ``interval`` keys that define a
dynamic evaluation schedule. Default: 1.
dist (bool, optional): Whether to use distributed evaluation.
Defaults to True.
metrics (dict | list[dict], optional): Configs for metrics that will be
@@ -90,12 +124,39 @@ def __init__(self,
best_metric='fid'):
assert metrics is not None
self.dataloader = dataloader
self.dist = dist
self.sample_kwargs = sample_kwargs if sample_kwargs else dict()
self.save_best_ckpt = save_best_ckpt
self.best_metric = best_metric

if isinstance(interval, int):
self.interval = interval
elif isinstance(interval, dict):
if 'milestones' not in interval or 'interval' not in interval:
raise KeyError(
'`milestones` and `interval` must exist in interval dict '
'if you want to use the dynamic interval evaluation '
f'strategy. But received {list(interval.keys())} '
'in the interval dict.')

self.milestones = interval['milestones']
self.interval = interval['interval']
# check if length of interval match with the milestones
if len(self.interval) != len(self.milestones) + 1:
raise ValueError(
f'Length of `interval`(={len(self.interval)}) does not '
f'match length of `milestones`(={len(self.milestones)}).')

# check if milestones is in order
for idx in range(len(self.milestones) - 1):
former, latter = self.milestones[idx], self.milestones[idx + 1]
if former >= latter:
raise ValueError(
'Elements in `milestones` should be in ascending order.')
else:
raise TypeError('`interval` only supports `int` or `dict`, '
f'but received {type(interval)} instead.')

if isinstance(best_metric, str):
self.best_metric = [self.best_metric]

@@ -129,6 +190,14 @@ def __init__(self,
self.rule[name]]
self._curr_best_ckpt_path[name] = None

def get_current_interval(self, runner):
"""Return the evaluation interval at the current iteration."""
if isinstance(self.interval, int):
return self.interval
else:
curr_iter = runner.iter + 1
index = bisect_right(self.milestones, curr_iter)
return self.interval[index]

def before_run(self, runner):
"""The behavior before running.
@@ -147,7 +216,8 @@ def after_train_iter(self, runner):
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
interval = self.get_current_interval(runner)
if not self.every_n_iters(runner, interval):
return

runner.model.eval()
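
The dynamic-interval lookup added above hinges on `bisect_right`: it counts how many milestones the (1-based) current iteration has passed and uses that count to index the interval list. A small standalone check of that logic, using the values from the docstring example:

from bisect import bisect_right

# Mirrors get_current_interval for
# interval=dict(milestones=[500000, 750000], interval=[10000, 4000, 2000]).
milestones = [500000, 750000]
intervals = [10000, 4000, 2000]
for curr_iter in (1, 499999, 500000, 749999, 750000, 900000):
    print(curr_iter, '->', intervals[bisect_right(milestones, curr_iter)])
# 1 -> 10000, 499999 -> 10000, 500000 -> 4000,
# 749999 -> 4000, 750000 -> 2000, 900000 -> 2000
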
4 changes: 2 additions & 2 deletions mmgen/models/architectures/__init__.py
@@ -1,4 +1,4 @@
from .biggan import BigGANGenerator, SNConvModule
from .cyclegan import ResnetGenerator
from .dcgan import DCGANDiscriminator, DCGANGenerator
from .fid_inception import InceptionV3
@@ -31,5 +31,5 @@
'generation_init_weights', 'PatchDiscriminator', 'ResnetGenerator',
'PerceptualLoss', 'WGANGPDiscriminator', 'WGANGPGenerator',
'LSGANDiscriminator', 'LSGANGenerator', 'ProjDiscriminator',
'SNGANGenerator', 'BigGANGenerator', 'SNConvModule'
]
11 changes: 8 additions & 3 deletions mmgen/models/architectures/biggan/__init__.py
@@ -1,8 +1,13 @@
from .generator_discriminator import BigGANDiscriminator, BigGANGenerator
from .generator_discriminator_deep import (BigGANDeepDiscriminator,
BigGANDeepGenerator)
from .modules import (BigGANConditionBN, BigGANDeepDiscResBlock,
BigGANDeepGenResBlock, BigGANDiscResBlock,
BigGANGenResBlock, SelfAttentionBlock, SNConvModule)

__all__ = [
'BigGANGenerator', 'BigGANGenResBlock', 'BigGANConditionBN',
'BigGANDiscriminator', 'SelfAttentionBlock', 'BigGANDiscResBlock',
'BigGANDeepDiscriminator', 'BigGANDeepGenerator', 'BigGANDeepDiscResBlock',
'BigGANDeepGenResBlock', 'SNConvModule'
]