Skip to content

Commit

Permalink
[Feature] Support styleclip (#236)
Browse files Browse the repository at this point in the history
* support styleclip

* fix lint

* add clip to requirement

* fix lint

* fix runtime.txt

* fix runtime.txt

* complete unittest

* remove third party repo

* fix lint

* fix docstring

* move clip import into init function

* fix lint

* remove a unittest
  • Loading branch information
plyfager committed Jan 26, 2022
1 parent 40376b4 commit fc6f842
Show file tree
Hide file tree
Showing 11 changed files with 961 additions and 5 deletions.
213 changes: 213 additions & 0 deletions apps/styleclip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import argparse
import math
import os

import clip
import mmcv
import torch
import torchvision
from mmcv import Config, DictAction
from torch import optim
from tqdm import tqdm

from mmgen.apis import init_model
from mmgen.models.losses import CLIPLoss, FaceIdLoss

from mmgen.apis import set_random_seed # isort:skip # noqa


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
lr_ramp = lr_ramp * min(1, t / rampup)

return initial_lr * lr_ramp


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config', help='model config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--use-cpu',
action='store_true',
help='whether to use cpu device for sampling')
parser.add_argument(
'--description',
type=str,
default='a person with purple hair',
help='the text that guides the editing/generation')
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument(
'--mode',
type=str,
default='generate',
choices=['edit', 'generate'],
help='choose between edit an image an generate a free one')
parser.add_argument(
'--l2-lambda',
type=float,
default=0.008,
help='weight of the latent distance, used for editing only')
parser.add_argument(
'--id-lambda',
type=float,
default=0.000,
help='weight of id loss, used for editing only')
parser.add_argument(
'--proj-latent',
type=str,
default=None,
help='Projection image files produced by stylegan_projector.py. If this \
argument is given, then the projected latent will be used as the init\
latent.')
parser.add_argument(
'--truncation',
type=float,
default=0.7,
help='used only for the initial latent vector, and only when a latent '
'code path is not provided')
parser.add_argument(
'--step', type=int, default=2000, help='Optimization iterations')
parser.add_argument(
'--save_intermediate_image_every',
type=int,
default=20,
help='if > 0 then saves intermidate results during the optimization')
parser.add_argument(
'--results_dir', type=str, default='work_dirs/styleclip/')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')

args = parser.parse_args()
return args


def main():
args = parse_args()
# set cudnn_benchmark
cfg = Config.fromfile(args.config)
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

# set random seeds
if args.seed is not None:
print('set random seed to', args.seed)
set_random_seed(args.seed, deterministic=args.deterministic)

os.makedirs(args.results_dir, exist_ok=True)

text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()

model = init_model(args.config, args.checkpoint, device='cpu')
g_ema = model.generator_ema
g_ema.eval()
if not args.use_cpu:
g_ema = g_ema.cuda()

mean_latent = g_ema.get_mean_latent()

# if given proj_latent
if args.proj_latent is not None:
mmcv.print_log(f'Load projected latent: {args.proj_latent}', 'mmgen')
proj_file = torch.load(args.proj_latent)
proj_n = len(proj_file)
assert proj_n == 1
noise_batch = []
for img_path in proj_file:
noise_batch.append(proj_file[img_path]['latent'].unsqueeze(0))
latent_code_init = torch.cat(noise_batch, dim=0).cuda()
elif args.mode == 'edit':
latent_code_init_not_trunc = torch.randn(1, 512).cuda()
with torch.no_grad():
results = g_ema([latent_code_init_not_trunc],
return_latents=True,
truncation=args.truncation,
truncation_latent=mean_latent)
latent_code_init = results['latent']
else:
latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

with torch.no_grad():
img_orig = g_ema([latent_code_init],
input_is_latent=True,
randomize_noise=False)

latent = latent_code_init.detach().clone()
latent.requires_grad = True

clip_loss = CLIPLoss(clip_model=dict(in_size=g_ema.out_size))
id_loss = FaceIdLoss(
facenet=dict(type='ArcFace', ir_se50_weights=None, device='cuda'))

optimizer = optim.Adam([latent], lr=args.lr)

pbar = tqdm(range(args.step))
mmcv.print_log(f'Description: {args.description}')
for i in pbar:
t = i / args.step
lr = get_lr(t, args.lr)
optimizer.param_groups[0]['lr'] = lr

img_gen = g_ema([latent], input_is_latent=True, randomize_noise=False)

img_gen = img_gen[:, [2, 1, 0], ...]

# clip loss
c_loss = clip_loss(image=img_gen, text=text_inputs)

if args.id_lambda > 0:
i_loss = id_loss(pred=img_gen, gt=img_orig)[0]
else:
i_loss = 0

if args.mode == 'edit':
l2_loss = ((latent_code_init - latent)**2).sum()
loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
else:
loss = c_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()

pbar.set_description((f'loss: {loss.item():.4f};'))
if args.save_intermediate_image_every > 0 and (
i % args.save_intermediate_image_every == 0):
with torch.no_grad():
img_gen = g_ema([latent],
input_is_latent=True,
randomize_noise=False)

img_gen = img_gen[:, [2, 1, 0], ...]

torchvision.utils.save_image(
img_gen,
os.path.join(args.results_dir, f'{str(i).zfill(5)}.png'),
normalize=True,
range=(-1, 1))

if args.mode == 'edit':
img_orig = img_orig[:, [2, 1, 0], ...]
final_result = torch.cat([img_orig, img_gen])
else:
final_result = img_gen

torchvision.utils.save_image(
final_result.detach().cpu(),
os.path.join(args.results_dir, 'final_result.png'),
normalize=True,
scale_each=True,
range=(-1, 1))


if __name__ == '__main__':
main()
4 changes: 3 additions & 1 deletion mmgen/models/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .arcface import IDLossModel
from .biggan import (BigGANDeepDiscriminator, BigGANDeepGenerator,
BigGANDiscriminator, BigGANGenerator, SNConvModule)
from .cyclegan import ResnetGenerator
Expand Down Expand Up @@ -35,5 +36,6 @@
'PerceptualLoss', 'WGANGPDiscriminator', 'WGANGPGenerator',
'LSGANDiscriminator', 'LSGANGenerator', 'ProjDiscriminator',
'SNGANGenerator', 'BigGANGenerator', 'SNConvModule', 'BigGANDiscriminator',
'BigGANDeepGenerator', 'BigGANDeepDiscriminator', 'DenoisingUnet'
'BigGANDeepGenerator', 'BigGANDeepDiscriminator', 'DenoisingUnet',
'IDLossModel'
]
3 changes: 3 additions & 0 deletions mmgen/models/architectures/arcface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .id_loss import IDLossModel

__all__ = ['IDLossModel']
Loading

0 comments on commit fc6f842

Please sign in to comment.