Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/open-mmlab/mmgeneration i…
Browse files Browse the repository at this point in the history
…nto fix-wflow
  • Loading branch information
yangyifei authored and yangyifei committed Sep 12, 2021
2 parents 3fffe1c + bfd00c3 commit d4358dc
Show file tree
Hide file tree
Showing 52 changed files with 1,902 additions and 1,719 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ jobs:
torch_version: torch1.8.0
torchvision: 0.9.0+cu101
mmcv: "latest+torch1.8.0+cu101"
- torch: 1.8.0+cu101
torch_version: torch1.8.0
torchvision: 0.9.0+cu101
mmcv: "latest+torch1.8.0+cu101"
python-version: 3.9

steps:
- uses: actions/checkout@v2
Expand Down
8 changes: 8 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- name: "MMGeneration Contributors"
title: "OpenMMLab's next-generation toolbox for generative models"
date-released: 2020-07-10
url: "https://github.com/open-mmlab/mmgeneration"
license: Apache-2.0
42 changes: 27 additions & 15 deletions configs/_base_/datasets/paired_imgs_256x256_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,65 @@
train_dataset_type = 'PairedImageDataset'
val_dataset_type = 'PairedImageDataset'
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
domain_a = 'photo'
domain_b = 'mask'
train_pipeline = [
dict(
type='LoadPairedImageFromFile',
io_backend='disk',
key='pair',
domain_a=domain_a,
domain_b=domain_b,
flag='color'),
dict(
type='Resize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
scale=(286, 286),
interpolation='bicubic'),
dict(type='FixedCrop', keys=['img_a', 'img_b'], crop_size=(256, 256)),
dict(type='Flip', keys=['img_a', 'img_b'], direction='horizontal'),
dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
dict(
type='FixedCrop',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
crop_size=(256, 256)),
dict(
type='Flip',
keys=[f'img_{domain_a}', f'img_{domain_b}'],
direction='horizontal'),
dict(type='RescaleToZeroOne', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Normalize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
to_rgb=False,
**img_norm_cfg),
dict(type='ImageToTensor', keys=['img_a', 'img_b']),
dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Collect',
keys=['img_a', 'img_b'],
meta_keys=['img_a_path', 'img_b_path'])
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]
test_pipeline = [
dict(
type='LoadPairedImageFromFile',
io_backend='disk',
key='pair',
key='image',
domain_a=domain_a,
domain_b=domain_b,
flag='color'),
dict(
type='Resize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
scale=(256, 256),
interpolation='bicubic'),
dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
dict(type='RescaleToZeroOne', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Normalize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
to_rgb=False,
**img_norm_cfg),
dict(type='ImageToTensor', keys=['img_a', 'img_b']),
dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Collect',
keys=['img_a', 'img_b'],
meta_keys=['img_a_path', 'img_b_path'])
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]

data = dict(
Expand Down
58 changes: 32 additions & 26 deletions configs/_base_/datasets/unpaired_imgs_256x256.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
train_dataset_type = 'UnpairedImageDataset'
val_dataset_type = 'UnpairedImageDataset'
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
domain_a = None # set by user
domain_b = None # set by user
train_pipeline = [
dict(
type='LoadImageFromFile', io_backend='disk', key='img_a',
type='LoadImageFromFile',
io_backend='disk',
key=f'img_{domain_a}',
flag='color'),
dict(
type='LoadImageFromFile', io_backend='disk', key='img_b',
type='LoadImageFromFile',
io_backend='disk',
key=f'img_{domain_b}',
flag='color'),
dict(
type='Resize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
scale=(286, 286),
interpolation='bicubic'),
dict(
type='Crop',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
crop_size=(256, 256),
random_crop=True),
dict(type='Flip', keys=['img_a'], direction='horizontal'),
dict(type='Flip', keys=['img_b'], direction='horizontal'),
dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'),
dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'),
dict(type='RescaleToZeroOne', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Normalize',
keys=['img_a', 'img_b'],
keys=[f'img_{domain_a}', f'img_{domain_b}'],
to_rgb=False,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]),
dict(type='ImageToTensor', keys=['img_a', 'img_b']),
dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']),
dict(
type='Collect',
keys=['img_a', 'img_b'],
meta_keys=['img_a_path', 'img_b_path'])
keys=[f'img_{domain_a}', f'img_{domain_b}'],
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path'])
]
test_pipeline = [
dict(
type='LoadImageFromFile', io_backend='disk', key='img_a',
flag='color'),
dict(
type='LoadImageFromFile', io_backend='disk', key='img_b',
type='LoadImageFromFile', io_backend='disk', key='image',
flag='color'),
dict(
type='Resize',
keys=['img_a', 'img_b'],
keys=['image'],
scale=(256, 256),
interpolation='bicubic'),
dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
dict(type='RescaleToZeroOne', keys=['image']),
dict(
type='Normalize',
keys=['img_a', 'img_b'],
keys=['image'],
to_rgb=False,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]),
dict(type='ImageToTensor', keys=['img_a', 'img_b']),
dict(
type='Collect',
keys=['img_a', 'img_b'],
meta_keys=['img_a_path', 'img_b_path'])
dict(type='ImageToTensor', keys=['image']),
dict(type='Collect', keys=['image'], meta_keys=['image_path'])
]
data_root = None
data = dict(
Expand All @@ -69,14 +69,20 @@
type=train_dataset_type,
dataroot=data_root,
pipeline=train_pipeline,
test_mode=False),
test_mode=False,
domain_a=domain_a,
domain_b=domain_b),
val=dict(
type=val_dataset_type,
dataroot=data_root,
pipeline=test_pipeline,
test_mode=True),
test_mode=True,
domain_a=domain_a,
domain_b=domain_b),
test=dict(
type=val_dataset_type,
dataroot=data_root,
pipeline=test_pipeline,
test_mode=True))
test_mode=True,
domain_a=domain_a,
domain_b=domain_b))
34 changes: 30 additions & 4 deletions configs/_base_/models/cyclegan_lsgan_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,33 @@
real_label_val=1.0,
fake_label_val=0.0,
loss_weight=1.0),
cycle_loss=dict(type='L1Loss', loss_weight=10.0, reduction='mean'),
id_loss=dict(type='L1Loss', loss_weight=0.5, reduction='mean'))
train_cfg = dict(direction='a2b', buffer_size=50)
test_cfg = dict(direction='a2b', show_input=True)
default_domain=None, # set by user
reachable_domains=None, # set by user
related_domains=None, # set by user
gen_auxiliary_loss=[
dict(
type='L1Loss',
loss_weight=10.0,
data_info=dict(pred='cycle_photo', target='real_photo'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=10.0,
data_info=dict(
pred='cycle_mask',
target='real_mask',
),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
data_info=dict(pred='identity_photo', target='real_photo'),
reduction='mean'),
dict(
type='L1Loss',
loss_weight=0.5,
data_info=dict(pred='identity_mask', target='real_mask'),
reduction='mean')
])
train_cfg = dict(buffer_size=50)
test_cfg = None
16 changes: 13 additions & 3 deletions configs/_base_/models/pix2pix_vanilla_unet_bn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
source_domain = None # set by user
target_domain = None # set by user
# model settings
model = dict(
type='Pix2Pix',
Expand All @@ -23,7 +25,15 @@
real_label_val=1.0,
fake_label_val=0.0,
loss_weight=1.0),
pixel_loss=dict(type='L1Loss', loss_weight=100.0, reduction='mean'))
default_domain=target_domain,
reachable_domains=[target_domain],
related_domains=[target_domain, source_domain],
gen_auxiliary_loss=dict(
type='L1Loss',
loss_weight=100.0,
data_info=dict(
pred=f'fake_{target_domain}', target=f'real_{target_domain}'),
reduction='mean'))
# model training and testing settings
train_cfg = dict(direction='a2b') # model default: a2b
test_cfg = dict(direction='a2b', show_input=True)
train_cfg = None
test_cfg = None
7 changes: 3 additions & 4 deletions configs/biggan/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS

# Large Scale GAN Training for High Fidelity Natural Image Synthesis
## Introduction
<!-- [ALGORITHM] -->
```latex
Expand Down Expand Up @@ -29,10 +28,10 @@ We have finished training `BigGAN` in `Cifar10` (32x32) and are aligning trainin
</div>

Evaluation of our trained BIgGAN.
| Models | Dataset | FID (Iter) | IS (Iter) | Config | Download |
| Models | Dataset | Best FID (Iter) | Best IS (Iter) | Config | Download |
|:------------:|:-------:|:--------------:|:---------------:|:------:|:----------:|
| BigGAN 32x32 | CIFAR10 | 9.78(390000) | 8.70(390000) | [config](https://github.com/open-mmlab/mmgeneration/blob/master/configs/biggan/biggan_cifar10_32x32_b25x2_500k.py) | [model](https://download.openmmlab.com/mmgen/biggan/biggan_cifar10_32x32_b25x2_500k_20210728_110906-08b61a44.pth)\|[log](https://download.openmmlab.com/mmgen/biggan/biggan_cifar10_32_b25x2_500k_20210706_171051.log.json) |
| BigGAN 128x128 | ImageNet1k | 12.32(1150000) | 72.7(1150000) | [config](https://github.com/open-mmlab/mmgeneration/blob/master/configs/biggan/biggan_imagenet1k_128x128_b32x8_1500k.py) | [model](https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_1150k_20210730_124753-b14026b7.pth)\|[log](https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_1500k_20210726_224316.log.json) |
| BigGAN 128x128 | ImageNet1k | 10.02(1449000) | 86.8(1449000) | [config](https://github.com/open-mmlab/mmgeneration/blob/master/configs/biggan/biggan_imagenet1k_128x128_b32x8_1500k.py) | [model](https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_1449000_20210906_141519-f9128faf.pth?versionId=CAEQMhiBgIDn.ILn3RciIDFiMTZhZjIxYzA2MjQxMTJiMDQzZjQyNWQ5YTVkY2Jl)\|[log](https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_1449000_20210906_141519-f9128faf.log.json?versionId=CAEQMhiBgMD5.dfo3RciIDE5OThmOWM0ZDAxYjQxMWJiNTEzMTFlMzAxZWJkOGFm) |

Note: This is an unfinished version (1150k iter) of BigGAN trained on `ImageNet1k`. The model with the best performance is still on the way.

Expand Down
Loading

0 comments on commit d4358dc

Please sign in to comment.