Skip to content

Commit

Permalink
Better code quality, less dependency. (NVIDIA#12)
Browse files Browse the repository at this point in the history
* Removed the dependency on opencv

The project's dependencies on opencv are sparse and this makes `Propagator`'s logic consistent with `smooth_filter`.

* Redundant code removed.

* Duplicated code in examples merged. Better timing.
  • Loading branch information
suquark authored and mingyuliutw committed Feb 24, 2018
1 parent 182b5c8 commit 83b5aba
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 213 deletions.
2 changes: 1 addition & 1 deletion USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses

### Setup

We only test our code in the following environment.
We only tested our code in the following environment.
- OS: Ubuntu 16.04
- CUDA: 9.1
- **Python 2 from Anaconda2**
Expand Down
66 changes: 11 additions & 55 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,11 @@
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import os
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.utils as utils

import argparse
import time
import numpy as np
import cv2
from PIL import Image
from photo_wct import PhotoWCT
from photo_smooth import Propagator
from smooth_filter import smooth_filter

import process_stylization
from photo_wct import PhotoWCT

parser = argparse.ArgumentParser(description='Photorealistic Image Stylization')
parser.add_argument('--vgg1', default='./models/vgg_normalised_conv1_1_mask.t7', help='Path to the VGG conv1_1')
Expand All @@ -37,48 +28,13 @@

# Load model
p_wct = PhotoWCT(args)
p_pro = Propagator()
p_wct.cuda(0)

content_image_path = args.content_image_path
content_seg_path = args.content_seg_path
style_image_path = args.style_image_path
style_seg_path = args.style_seg_path
output_image_path = args.output_image_path

# Load image
cont_img = Image.open(content_image_path).convert('RGB')
styl_img = Image.open(style_image_path).convert('RGB')
try:
cont_seg = Image.open(content_seg_path)
styl_seg = Image.open(style_seg_path)
except:
cont_seg = []
styl_seg = []

cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)
styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)
cont_img = Variable(cont_img.cuda(0), volatile=True)
styl_img = Variable(styl_img.cuda(0), volatile=True)

cont_seg = np.asarray(cont_seg)
styl_seg = np.asarray(styl_seg)

start_style_time = time.time()
stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg)
end_style_time = time.time()
print('Elapsed time in stylization: %f' % (end_style_time - start_style_time))
utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1)

start_propagation_time = time.time()
out_img = p_pro.process(output_image_path, content_image_path)
end_propagation_time = time.time()
print('Elapsed time in propagation: %f' % (end_propagation_time - start_propagation_time))
cv2.imwrite(output_image_path, out_img)

start_postprocessing_time = time.time()
out_img = smooth_filter(output_image_path, content_image_path, f_radius=15, f_edge=1e-1)
end_postprocessing_time = time.time()
print('Elapsed time in post processing: %f' % (end_postprocessing_time - start_postprocessing_time))

out_img.save(output_image_path)
process_stylization.stylization(
p_wct=p_wct,
content_image_path=args.content_image_path,
style_image_path=args.style_image_path,
content_seg_path=args.content_seg_path,
style_seg_path=args.style_seg_path,
output_image_path=args.output_image_path,
)
Loading

0 comments on commit 83b5aba

Please sign in to comment.