diff --git a/configs/_base_/models/cyclegan_lsgan_resnet.py b/configs/_base_/models/cyclegan_lsgan_resnet.py index b953e1b43..38f907783 100644 --- a/configs/_base_/models/cyclegan_lsgan_resnet.py +++ b/configs/_base_/models/cyclegan_lsgan_resnet.py @@ -1,3 +1,5 @@ +_domain_a = None # set by user +_domain_b = None # set by user model = dict( type='CycleGAN', generator=dict( @@ -30,25 +32,32 @@ dict( type='L1Loss', loss_weight=10.0, - data_info=dict(pred='cycle_photo', target='real_photo'), + loss_name='cycle_loss', + data_info=dict( + pred=f'cycle_{_domain_a}', target=f'real_{_domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( - pred='cycle_mask', - target='real_mask', + pred=f'cycle_{_domain_b}', + target=f'real_{_domain_b}', ), reduction='mean'), dict( type='L1Loss', loss_weight=0.5, - data_info=dict(pred='identity_photo', target='real_photo'), + loss_name='id_loss', + data_info=dict( + pred=f'identity_{_domain_a}', target=f'real_{_domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=0.5, - data_info=dict(pred='identity_mask', target='real_mask'), + loss_name='id_loss', + data_info=dict( + pred=f'identity_{_domain_b}', target=f'real_{_domain_b}'), reduction='mean') ]) train_cfg = dict(buffer_size=50) diff --git a/configs/_base_/models/pix2pix_vanilla_unet_bn.py b/configs/_base_/models/pix2pix_vanilla_unet_bn.py index 13b1dfd3d..66e40ddf2 100644 --- a/configs/_base_/models/pix2pix_vanilla_unet_bn.py +++ b/configs/_base_/models/pix2pix_vanilla_unet_bn.py @@ -31,6 +31,7 @@ gen_auxiliary_loss=dict( type='L1Loss', loss_weight=100.0, + loss_name='pixel_loss', data_info=dict( pred=f'fake_{target_domain}', target=f'real_{target_domain}'), reduction='mean')) diff --git a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_250k_summer2winter.py b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_250k_summer2winter.py index 64807d4be..467a468b0 100644 --- a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_250k_summer2winter.py +++ b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_250k_summer2winter.py @@ -13,12 +13,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -74,7 +76,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=125000, interval=1250) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( diff --git a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_270k_horse2zebra.py b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_270k_horse2zebra.py index 38acfbcb6..2dc7bcbd1 100644 --- a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_270k_horse2zebra.py +++ b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_270k_horse2zebra.py @@ -13,12 +13,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -74,7 +76,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=135000, interval=1350) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( @@ -88,7 +94,7 @@ use_ddp_wrapper = True total_iters = 270000 workflow = [('train', 1)] -exp_name = 'cyclegan_facades_id0' +exp_name = 'cyclegan_horse2zebra_id0' work_dir = f'./work_dirs/experiments/{exp_name}' metrics = dict( FID=dict(type='FID', num_images=140, image_shape=(3, 256, 256)), diff --git a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades.py b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades.py index 477013f88..f711fd166 100644 --- a/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades.py +++ b/configs/cyclegan/cyclegan_lsgan_id0_resnet_in_1x1_80k_facades.py @@ -15,12 +15,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -82,7 +84,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=40000, interval=400) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( diff --git a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_250k_summer2winter.py b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_250k_summer2winter.py index 71a7625a1..c78ae4782 100644 --- a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_250k_summer2winter.py +++ b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_250k_summer2winter.py @@ -14,12 +14,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -28,12 +30,14 @@ dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_b}', target=f'real_{domain_b}'), reduction='mean') @@ -87,7 +91,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=125000, interval=1250) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( diff --git a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py index 99ddf9d74..7fac786fb 100644 --- a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py +++ b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_270k_horse2zebra.py @@ -14,12 +14,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -28,12 +30,14 @@ dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_b}', target=f'real_{domain_b}'), reduction='mean') @@ -87,7 +91,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=135000, interval=1350) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( @@ -101,7 +109,7 @@ use_ddp_wrapper = True total_iters = 270000 workflow = [('train', 1)] -exp_name = 'cyclegan_facades_id0' +exp_name = 'cyclegan_horse2zebra' work_dir = f'./work_dirs/experiments/{exp_name}' # testA 120, testB 140 metrics = dict( diff --git a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py index 32f0ccdeb..d9dedd8f4 100644 --- a/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py +++ b/configs/cyclegan/cyclegan_lsgan_resnet_in_1x1_80k_facades.py @@ -16,12 +16,14 @@ dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=10.0, + loss_name='cycle_loss', data_info=dict( pred=f'cycle_{domain_b}', target=f'real_{domain_b}', @@ -30,12 +32,14 @@ dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_a}', target=f'real_{domain_a}'), reduction='mean'), dict( type='L1Loss', loss_weight=0.5, + loss_name='id_loss', data_info=dict( pred=f'identity_{domain_b}', target=f'real_{domain_b}'), reduction='mean') @@ -49,7 +53,11 @@ optimizer = dict( generators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), discriminators=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) -lr_config = None + +# learning policy +lr_config = dict( + policy='Linear', by_epoch=False, target_lr=0, start=40000, interval=400) + checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) custom_hooks = [ dict( diff --git a/configs/pix2pix/README.md b/configs/pix2pix/README.md index 53ec8966a..ba930674f 100644 --- a/configs/pix2pix/README.md +++ b/configs/pix2pix/README.md @@ -25,21 +25,21 @@ We use `FID` and `IS` metrics to evaluate the generation performance of pix2pix. `FID` evaluation: -| Dataset | [facades](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py) | [maps-a2b](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_a2b_1x1_220k_maps.py) | [maps-b2a](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_b2a_1x1_220k_maps.py) | [edges2shoes](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_190k_edges2shoes.py) | average | +| Dataset | [facades](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades.py) | [aerial2maps](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py) | [maps2aerial](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py) | [edges2shoes](https://github.com/open-mmlab/mmgeneration/tree/master/configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_190k_edges2shoes.py) | average | | :------: | :--------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------: | :----------: | | official | **119.135** | 149.731 | 102.072 | **75.774** | 111.678 | | ours | 124.372 | **122.691** | **88.378** | 85.144 | **105.1463** | `IS` evaluation: -| Dataset | facades | maps-a2b | maps-b2a | edges2shoes | average | +| Dataset | facades | aerial2maps | maps2aerial | edges2shoes | average | | :------: | :-------: | :-------: | :-------: | :---------: | :-------: | | official | 1.650 | 2.529 | 3.552 | 2.766 | 2.624 | | ours | **1.665** | **3.337** | **3.585** | **2.797** | **2.846** | Model and log downloads: -| Dataset | facades | maps-a2b | maps-b2a | edges2shoes | +| Dataset | facades | aerial2maps | maps2aerial | edges2shoes | | :------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------: | | download | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210902_170442-c0958d50.pth?versionId=CAEQMhiBgICb8fTj3RciIGU2NmViM2QyYzJkODQ0MDBhYTFhMGE2YzNmMTA0ODk3) \| [log](https://download.openmmlab.com/mmgen/pix2pix/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210317_172625.log.json)2 | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_a2b_1x1_219200_maps_convert-bgr_20210902_170729-59a31517.pth?versionId=CAEQMhiBgICH9vTj3RciIDdiNGRmYTNlZjhlMjQ0ODc4OTJiOGEzMjY0YTJlZmQ5) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_b2a_1x1_219200_maps_convert-bgr_20210902_170814-6d2eac4a.pth?versionId=CAEQMhiBgMC08PTj3RciIGE4ODVkZWU0MTYyMTQ0MWJhZjE0YThmY2M2NDJmZjNi) | [model](https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth?versionId=CAEQMhiBgIC57vTj3RciIGZlNmQ4ZDJhN2E1MDQ5ZmJiOWJmYTY5MDg1ZTc0N2Vi) | diff --git a/configs/pix2pix/pix2pix_vanilla_unet_bn_a2b_1x1_220k_maps.py b/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py similarity index 99% rename from configs/pix2pix/pix2pix_vanilla_unet_bn_a2b_1x1_220k_maps.py rename to configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py index dd3857d1e..9c1db49f8 100644 --- a/configs/pix2pix/pix2pix_vanilla_unet_bn_a2b_1x1_220k_maps.py +++ b/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_aerial2maps.py @@ -105,7 +105,7 @@ # runtime settings total_iters = 220000 workflow = [('train', 1)] -exp_name = 'pix2pix_maps_a2b' +exp_name = 'pix2pix_aerial2map' work_dir = f'./work_dirs/experiments/{exp_name}' metrics = dict( FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)), diff --git a/configs/pix2pix/pix2pix_vanilla_unet_bn_b2a_1x1_220k_maps.py b/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py similarity index 99% rename from configs/pix2pix/pix2pix_vanilla_unet_bn_b2a_1x1_220k_maps.py rename to configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py index 131744db4..f9dcda65f 100644 --- a/configs/pix2pix/pix2pix_vanilla_unet_bn_b2a_1x1_220k_maps.py +++ b/configs/pix2pix/pix2pix_vanilla_unet_bn_1x1_220k_maps2aerial.py @@ -104,7 +104,7 @@ # runtime settings total_iters = 220000 workflow = [('train', 1)] -exp_name = 'pix2pix_maps_b2a' +exp_name = 'pix2pix_maps2aerial' work_dir = f'./work_dirs/experiments/{exp_name}' metrics = dict( FID=dict(type='FID', num_images=1098, image_shape=(3, 256, 256)), diff --git a/mmgen/apis/inference.py b/mmgen/apis/inference.py index 24761de78..3e914d35f 100644 --- a/mmgen/apis/inference.py +++ b/mmgen/apis/inference.py @@ -192,6 +192,7 @@ def sample_img2img_model(model, image_path, target_domain, **kwargs): Tensor: Translated image tensor. """ assert isinstance(model, BaseTranslationModel) + assert target_domain in model._reachable_domains cfg = model._cfg device = next(model.parameters()).device # model device # build the data pipeline diff --git a/tests/test_models/test_pix2pix.py b/tests/test_models/test_pix2pix.py index 9d5bfeef0..eb76dc5c2 100644 --- a/tests/test_models/test_pix2pix.py +++ b/tests/test_models/test_pix2pix.py @@ -249,33 +249,3 @@ def test_pix2pix(): data_batch['img_photo']) assert torch.is_tensor(outputs['results']['fake_photo']) assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256) - - # test b2a translation - data_batch['img_mask'] = img_mask.cpu() - data_batch['img_photo'] = img_photo.cpu() - train_cfg = dict(direction='b2a') - synthesizer = build_model( - model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) - optimizer = { - 'generators': - obj_from_dict( - optim_cfg, torch.optim, - dict(params=getattr(synthesizer, 'generators').parameters())), - 'discriminators': - obj_from_dict( - optim_cfg, torch.optim, - dict(params=getattr(synthesizer, 'discriminators').parameters())) - } - outputs = synthesizer.train_step(data_batch, optimizer) - assert isinstance(outputs, dict) - assert isinstance(outputs['log_vars'], dict) - assert isinstance(outputs['results'], dict) - for v in ['loss_gan_d_fake', 'loss_gan_d_real', 'loss_gan_g', 'loss_l1']: - assert isinstance(outputs['log_vars'][v], float) - assert outputs['num_samples'] == 1 - assert torch.equal(outputs['results']['real_mask'], data_batch['img_mask']) - assert torch.equal(outputs['results']['real_photo'], - data_batch['img_photo']) - assert torch.is_tensor(outputs['results']['fake_photo']) - assert outputs['results']['fake_photo'].size() == (1, 3, 256, 256) - assert synthesizer.iteration == 1