Skip to content

Commit

Permalink
[Fix] Fix README and configs of translation models (#113)
Browse files Browse the repository at this point in the history
* fix names and lr schedule

* fix lint

Co-authored-by: yangyifei <PJLAB\yangyifei@shai14001042l.pjlab.org>
  • Loading branch information
plyfager and yangyifei committed Sep 15, 2021
1 parent 0032271 commit a9c0ac4
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 48 deletions.
19 changes: 14 additions & 5 deletions configs/_base_/models/cyclegan_lsgan_resnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
_domain_a = None # set by user
_domain_b = None # set by user
model = dict(
type='CycleGAN',
generator=dict(
Expand Down Expand Up @@ -30,25 +32,32 @@
dict(
type='L1Loss',
loss_weight=10.0,
data_info=dict(pred='cycle_photo', target='real_photo'),
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{_domain_a}', target=f'real_{_domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred='cycle_mask',
target='real_mask',
pred=f'cycle_{_domain_b}',
target=f'real_{_domain_b}',
),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
data_info=dict(pred='identity_photo', target='real_photo'),
loss_name='id_loss',
data_info=dict(
pred=f'identity_{_domain_a}', target=f'real_{_domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
data_info=dict(pred='identity_mask', target='real_mask'),
loss_name='id_loss',
data_info=dict(
pred=f'identity_{_domain_b}', target=f'real_{_domain_b}'),
reduction='mean')
])
train_cfg = dict(buffer_size=50)
Expand Down
1 change: 1 addition & 0 deletions configs/_base_/models/pix2pix_vanilla_unet_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
gen_auxiliary_loss=dict(
type='L1Loss',
loss_weight=100.0,
loss_name='pixel_loss',
data_info=dict(
pred=f'fake_{target_domain}', target=f'real_{target_domain}'),
reduction='mean'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand Down Expand Up @@ -74,7 +76,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=125000, interval=1250)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand Down Expand Up @@ -74,7 +76,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=135000, interval=1350)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand All @@ -88,7 +94,7 @@
use_ddp_wrapper = True
total_iters = 270000
workflow = [('train', 1)]
exp_name = 'cyclegan_facades_id0'
exp_name = 'cyclegan_horse2zebra_id0'
work_dir = f'./work_dirs/experiments/{exp_name}'
metrics = dict(
FID=dict(type='FID', num_images=140, image_shape=(3, 256, 256)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand Down Expand Up @@ -82,7 +84,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=40000, interval=400)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand All @@ -28,12 +30,14 @@
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_b}', target=f'real_{domain_b}'),
reduction='mean')
Expand Down Expand Up @@ -87,7 +91,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=125000, interval=1250)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand Down
12 changes: 10 additions & 2 deletions configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand All @@ -28,12 +30,14 @@
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_b}', target=f'real_{domain_b}'),
reduction='mean')
Expand Down Expand Up @@ -87,7 +91,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=135000, interval=1350)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand All @@ -101,7 +109,7 @@
use_ddp_wrapper = True
total_iters = 270000
workflow = [('train', 1)]
exp_name = 'cyclegan_facades_id0'
exp_name = 'cyclegan_horse2zebra'
work_dir = f'./work_dirs/experiments/{exp_name}'
# testA 120, testB 140
metrics = dict(
Expand Down
10 changes: 9 additions & 1 deletion configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
loss_name='cycle_loss',
data_info=dict(
pred=f'cycle_{domain_b}',
target=f'real_{domain_b}',
Expand All @@ -30,12 +32,14 @@
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_a}', target=f'real_{domain_a}'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
loss_name='id_loss',
data_info=dict(
pred=f'identity_{domain_b}', target=f'real_{domain_b}'),
reduction='mean')
Expand All @@ -49,7 +53,11 @@
optimizer = dict(
generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
lr_config = None

# learning policy
lr_config = dict(
policy='Linear', by_epoch=False, target_lr=0, start=40000, interval=400)

checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False)
custom_hooks = [
dict(
Expand Down
6 changes: 3 additions & 3 deletions configs/pix2pix/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ We use `FID` and `IS` metrics to evaluate the generation performance of pix2pix.

`FID` evaluation:

| Dataset | [facades](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py) | [maps-a2b](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_a2b_1x1_220k_maps.py) | [maps-b2a](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_b2a_1x1_220k_maps.py) | [edges2shoes](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_190k_edges2shoes.py) | average |
| Dataset | [facades](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py) | [aerial2maps](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py) | [maps2aerial](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py) | [edges2shoes](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_190k_edges2shoes.py) | average |
| :------: | :--------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------: | :----------: |
| official | **119.135** | 149.731 | 102.072 | **75.774** | 111.678 |
| ours | 124.372 | **122.691** | **88.378** | 85.144 | **105.146** |

`IS` evaluation:

| Dataset | facades | maps-a2b | maps-b2a | edges2shoes | average |
| Dataset | facades | aerial2maps | maps2aerial | edges2shoes | average |
| :------: | :-------: | :-------: | :-------: | :---------: | :-------: |
| official | 1.650 | 2.529 | 3.552 | 2.766 | 2.624 |
| ours | **1.665** | **3.337** | **3.585** | **2.797** | **2.846** |

Model and log downloads:

| Dataset | facades | maps-a2b | maps-b2a | edges2shoes |
| Dataset | facades | aerial2maps | maps2aerial | edges2shoes |
| :------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------: |
| download | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210902_170442-c0958d50.pth?versionId=CAEQMhiBgICb8fTj3RciIGU2NmViM2QyYzJkODQ0MDBhYTFhMGE2YzNmMTA0ODk3) \| [log](https://download.openmmlab.com/mmgen/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210317_172625.log.json)<sup>2</sup> | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_a2b_1x1_219200_maps_convert-bgr_20210902_170729-59a31517.pth?versionId=CAEQMhiBgICH9vTj3RciIDdiNGRmYTNlZjhlMjQ0ODc4OTJiOGEzMjY0YTJlZmQ5) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_b2a_1x1_219200_maps_convert-bgr_20210902_170814-6d2eac4a.pth?versionId=CAEQMhiBgMC08PTj3RciIGE4ODVkZWU0MTYyMTQ0MWJhZjE0YThmY2M2NDJmZjNi) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth?versionId=CAEQMhiBgIC57vTj3RciIGZlNmQ4ZDJhN2E1MDQ5ZmJiOWJmYTY5MDg1ZTc0N2Vi) |

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
# runtime settings
total_iters = 220000
workflow = [('train', 1)]
exp_name = 'pix2pix_maps_a2b'
exp_name = 'pix2pix_aerial2map'
work_dir = f'./work_dirs/experiments/{exp_name}'
metrics = dict(
FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
# runtime settings
total_iters = 220000
workflow = [('train', 1)]
exp_name = 'pix2pix_maps_b2a'
exp_name = 'pix2pix_maps2aerial'
work_dir = f'./work_dirs/experiments/{exp_name}'
metrics = dict(
FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)),
Expand Down
1 change: 1 addition & 0 deletions mmgen/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def sample_img2img_model(model, image_path, target_domain, **kwargs):
Tensor: Translated image tensor.
"""
assert isinstance(model, BaseTranslationModel)
assert target_domain in model._reachable_domains
cfg = model._cfg
device = next(model.parameters()).device # model device
# build the data pipeline
Expand Down
30 changes: 0 additions & 30 deletions tests/test_models/test_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,33 +249,3 @@ def test_pix2pix():
data_batch['img_photo'])
assert torch.is_tensor(outputs['results']['fake_photo'])
assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)

# test b2a translation
data_batch['img_mask'] = img_mask.cpu()
data_batch['img_photo'] = img_photo.cpu()
train_cfg = dict(direction='b2a')
synthesizer = build_model(
model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
optimizer = {
'generators':
obj_from_dict(
optim_cfg, torch.optim,
dict(params=getattr(synthesizer, 'generators').parameters())),
'discriminators':
obj_from_dict(
optim_cfg, torch.optim,
dict(params=getattr(synthesizer, 'discriminators').parameters()))
}
outputs = synthesizer.train_step(data_batch, optimizer)
assert isinstance(outputs, dict)
assert isinstance(outputs['log_vars'], dict)
assert isinstance(outputs['results'], dict)
for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1']:
assert isinstance(outputs['log_vars'][v], float)
assert outputs['num_samples'] == 1
assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask'])
assert torch.equal(outputs['results']['real_photo'],
data_batch['img_photo'])
assert torch.is_tensor(outputs['results']['fake_photo'])
assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256)
assert synthesizer.iteration == 1

0 comments on commit a9c0ac4

Please sign in to comment.