Commit 0b4f59a

update

mchong6 committed Nov 2, 2021
1 parent f4ecdfe commit 0b4f59a
Showing 3 changed files with 43 additions and 99 deletions.
12 changes: 12 additions & 0 deletions infinity.ipynb
@@ -1649,6 +1649,18 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
104 changes: 7 additions & 97 deletions model.py
@@ -538,101 +538,6 @@ def get_latent(self, input, is_latent=False, truncation=1, mean_latent=None):

        return output

    def blend_bbox(self, latent1, latent2, coord):
        def get_bbox_from_mask(img):
            img = img[0,0]
            a = torch.where(img != 0)

            y = torch.min(a[0])
            x = torch.min(a[1])
            h = torch.max(a[0]) - y
            w = torch.max(a[1]) - x

            return (y,x,h,w)

        noise = [getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)]

        coord = coord.astype('uint8')
        x1, y1, w1, h1 = coord[0]
        x2, y2, w2, h2 = coord[1]
        h = max(h1, h2)
        w = max(w1, w2)

        mask1 = torch.zeros([1,1,256,256]).cuda()
        mask1[..., y1:y1+h, x1:x1+w] = 1
        mask1 = k.gaussian_blur2d(mask1, (21,21), sigma=(10,10))

        mask2 = torch.zeros([1,1,256,256]).cuda()
        mask2[..., y2:y2+h, x2:x2+w] = 1
        mask2 = k.gaussian_blur2d(mask2, (21,21), sigma=(10,10))

        out = self.input(latent1[0])
        out1, _ = self.conv1(out, latent1[0], noise=noise[0])
        out2, _ = self.conv1(out, latent2[0], noise=noise[0])
        alpha1 = F.interpolate(mask1, size=out1.size()[2:], mode='bilinear')
        alpha2 = F.interpolate(mask2, size=out1.size()[2:], mode='bilinear')
        bbox1 = get_bbox_from_mask(alpha1)
        bbox2 = get_bbox_from_mask(alpha2)
        h = max(bbox1[2], bbox2[2])
        w = max(bbox1[3], bbox2[3])
        out = (1-alpha1)*out1
        out[..., bbox1[0]:bbox1[0]+h, bbox1[1]:bbox1[1]+w] += (alpha2*out2)[..., bbox2[0]:bbox2[0]+h, bbox2[1]:bbox2[1]+w]


        skip1 = self.to_rgb1(out, latent1[1])
        skip2 = self.to_rgb1(out, latent2[1])
        alpha1 = F.interpolate(mask1, size=skip1.size()[2:], mode='bilinear')
        alpha2 = F.interpolate(mask2, size=skip1.size()[2:], mode='bilinear')
        bbox1 = get_bbox_from_mask(alpha1)
        bbox2 = get_bbox_from_mask(alpha2)
        h = max(bbox1[2], bbox2[2])
        w = max(bbox1[3], bbox2[3])
        skip = (1-alpha1)*skip1
        skip[..., bbox1[0]:bbox1[0]+h, bbox1[1]:bbox1[1]+w] += (alpha2*skip2)[..., bbox2[0]:bbox2[0]+h, bbox2[1]:bbox2[1]+w]

        i = 2
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):

            out1, _ = conv1(out, latent1[i], noise=noise1)
            out2, _ = conv1(out, latent2[i], noise=noise1)
            alpha1 = F.interpolate(mask1, size=out1.size()[2:], mode='bilinear')
            alpha2 = F.interpolate(mask2, size=out1.size()[2:], mode='bilinear')
            bbox1 = get_bbox_from_mask(alpha1)
            bbox2 = get_bbox_from_mask(alpha2)
            h = max(bbox1[2], bbox2[2])
            w = max(bbox1[3], bbox2[3])
            out = (1-alpha1)*out1
            out[..., bbox1[0]:bbox1[0]+h, bbox1[1]:bbox1[1]+w] += (alpha2*out2)[..., bbox2[0]:bbox2[0]+h, bbox2[1]:bbox2[1]+w]

            out1, _ = conv2(out, latent1[i+1], noise=noise2)
            out2, _ = conv2(out, latent2[i+1], noise=noise2)
            alpha1 = F.interpolate(mask1, size=out1.size()[2:], mode='bilinear')
            alpha2 = F.interpolate(mask2, size=out1.size()[2:], mode='bilinear')
            bbox1 = get_bbox_from_mask(alpha1)
            bbox2 = get_bbox_from_mask(alpha2)
            h = max(bbox1[2], bbox2[2])
            w = max(bbox1[3], bbox2[3])
            out = (1-alpha1)*out1
            out[..., bbox1[0]:bbox1[0]+h, bbox1[1]:bbox1[1]+w] += (alpha2*out2)[..., bbox2[0]:bbox2[0]+h, bbox2[1]:bbox2[1]+w]

            skip1 = to_rgb(out, latent1[i+2], skip)
            skip2 = to_rgb(out, latent2[i+2], skip)
            alpha1 = F.interpolate(mask1, size=skip1.size()[2:], mode='bilinear')
            alpha2 = F.interpolate(mask2, size=skip1.size()[2:], mode='bilinear')
            bbox1 = get_bbox_from_mask(alpha1)
            bbox2 = get_bbox_from_mask(alpha2)
            h = max(bbox1[2], bbox2[2])
            w = max(bbox1[3], bbox2[3])
            skip = (1-alpha1)*skip1
            skip[..., bbox1[0]:bbox1[0]+h, bbox1[1]:bbox1[1]+w] += (alpha2*skip2)[..., bbox2[0]:bbox2[0]+h, bbox2[1]:bbox2[1]+w]

            i += 3

        image = skip.clamp(-1,1)
        return image
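For reference, the compositing step that the removed blend_bbox repeated at every resolution keeps latent1's features outside its box and pastes latent2's masked features into it. A minimal standalone sketch of that single step on toy tensors (shapes, box positions, and the omission of the Gaussian blur are illustrative assumptions, not the repository's settings):

import torch
import torch.nn.functional as F

def get_bbox_from_mask(img):
    # Bounding box of the nonzero region, as in the removed helper.
    ys, xs = torch.where(img[0, 0] != 0)
    y, x = ys.min(), xs.min()
    return y, x, ys.max() - y, xs.max() - x

# Toy stand-ins for two per-layer activations of the same shape.
out1 = torch.randn(1, 8, 64, 64)
out2 = torch.randn(1, 8, 64, 64)

# Hard 256x256 box masks; blend_bbox softens these with kornia's
# gaussian_blur2d, omitted here to keep the sketch dependency-free.
mask1 = torch.zeros(1, 1, 256, 256)
mask1[..., 40:120, 30:90] = 1
mask2 = torch.zeros(1, 1, 256, 256)
mask2[..., 100:180, 90:150] = 1

# Resize the masks to the activation resolution, as in each loop iteration.
alpha1 = F.interpolate(mask1, size=out1.shape[2:], mode='bilinear')
alpha2 = F.interpolate(mask2, size=out1.shape[2:], mode='bilinear')

y1, x1, h1, w1 = [int(v) for v in get_bbox_from_mask(alpha1)]
y2, x2, h2, w2 = [int(v) for v in get_bbox_from_mask(alpha2)]
h, w = max(h1, h2), max(w1, w2)

# Suppress out1 inside its box, then paste out2's masked box there.
out = (1 - alpha1) * out1
out[..., y1:y1+h, x1:x1+w] += (alpha2 * out2)[..., y2:y2+h, x2:x2+w]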

    def patch_swap(self, latent1, latent2, coord, swap=True):
        noise = [getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)]

@@ -731,10 +636,15 @@ def singan(self, latent, mode):



    def blend_mask(self, latent1, latent2, coord, num_blend=99, pose_align=False, pose_num=4):
    def blend_bbox(self, latent1, latent2, coord, model_type, num_blend=99):
        noise = [getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)]

        coord = coord.astype('uint8')
        if model_type == 'face':
            pose_align = True
            pose_num = 4
        else:
            pose_align = False

        x, y, w, h = coord[0]

        mask = torch.zeros([1,1,256,256]).cuda()
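The hunk above folds the old pose_align/pose_num arguments of blend_mask into a model_type switch on the renamed blend_bbox ('face' turns pose alignment on). A hypothetical call site under that reading; generator, latent1, latent2, and boxes are placeholders, not code from this commit:

import numpy as np
import torch

# prepare_bbox is the helper added to util.py below; it maps relative
# detector boxes to pixel-space (x, y, w, h) rows.
coord = np.array(prepare_bbox(boxes))
with torch.no_grad():
    blended = generator.blend_bbox(latent1, latent2, coord, model_type='face', num_blend=99)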
26 changes: 24 additions & 2 deletions util.py
@@ -9,7 +9,7 @@
import math
import scipy
import scipy.ndimage

import torchvision

# Number of style channels per StyleGAN layer
style2list_len = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512,
@@ -19,12 +19,22 @@
rgb_layer_idx = [1,4,7,10,13,16,19,22,25]

google_drive_paths = {
    "stylegan2-church-config-f.pt": "https://drive.google.com/uc?id=1ORsZHZEeFNEX9HtqRutt1jMgrf5Gpcat",
    "church.pt": "https://drive.google.com/uc?id=1ORsZHZEeFNEX9HtqRutt1jMgrf5Gpcat",
    "face.pt": "https://drive.google.com/uc?id=1dOBo4xWUwM7-BwHWZgp-kV1upaD6tHAh",
    "landscape.pt": "https://drive.google.com/uc?id=1rN5EhwiY95BBNPvOezhX4SZ_tEOR0qe2",
    "disney.pt": "https://drive.google.com/uc?id=1n2uQ5s2XdUBGIcZA9Uabz1mkjVvKWFeG",
    "010000.pt": "https://drive.google.com/uc?id=1hOq8zx0wVS3zqdfASXhzFre7DPi7Sel_",
    "model_ir_se50.pt": "https://drive.google.com/uc?id=1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn",
    "dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp",
    "e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1cUv_reLE6k3604or78EranS7XzuVMWeO"
}

@torch.no_grad()
def load_model(generator, model_file_path):
    ensure_checkpoint_exists(model_file_path)
    ckpt = torch.load(model_file_path, map_location=lambda storage, loc: storage)
    generator.load_state_dict(ckpt["g_ema"], strict=False)
    return generator.mean_latent(50000)
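A plausible way to use the new load_model helper, assuming the StyleGAN2 Generator class defined in model.py (the constructor arguments shown follow the common stylegan2-pytorch convention and are an assumption, as is the checkpoint choice):

import torch
from model import Generator

# Hypothetical: build a 256px generator and restore the face checkpoint.
# load_model calls ensure_checkpoint_exists first, which presumably fetches
# 'face.pt' from its google_drive_paths entry when it is not on disk, then
# loads the EMA weights and returns a mean latent over 50000 samples.
generator = Generator(256, 512, 8).cuda()
mean_latent = load_model(generator, 'face.pt')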

def ensure_checkpoint_exists(model_weights_filename):
    if not os.path.isfile(model_weights_filename) and (
@@ -330,3 +340,15 @@ def align_face(filepath, output_size=512):
    # Return aligned image.
    return img

def normalize(x):
    return (x+1)/2

def tensor2bbox_im(x):
    return np.array(torchvision.transforms.functional.to_pil_image(normalize(x[0])))

def prepare_bbox(boxes):
    output = []
    for i in range(len(boxes)):
        y1,x1,y2,x2 = boxes[i][0]
        output.append((256*np.array([x1,y1, x2-x1, y2-y1])).astype(np.uint8))
    return output
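The three helpers above convert a generator output into a uint8 image array and relative detector boxes into pixel-space (x, y, w, h) rows. A short sketch of how they compose; the [[y1, x1, y2, x2]] relative box layout is inferred from the unpacking in prepare_bbox:

import numpy as np
import torch
from util import prepare_bbox, tensor2bbox_im

# Hypothetical tensor in [-1, 1], standing in for a generator output.
img = torch.rand(1, 3, 256, 256) * 2 - 1
im = tensor2bbox_im(img)  # normalize() to [0, 1], then PIL -> HxWx3 uint8 array

# Two detections in relative [y1, x1, y2, x2] form (assumed layout).
boxes = [np.array([[0.10, 0.20, 0.50, 0.60]]),
         np.array([[0.30, 0.10, 0.90, 0.70]])]
coord = np.array(prepare_bbox(boxes))  # rows of (x, y, w, h), scaled by 256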
