-
Notifications
You must be signed in to change notification settings - Fork 2
/
perceptual_model.py
78 lines (63 loc) · 3.74 KB
/
perceptual_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing import image
import keras.backend as K
def load_images(images_list, img_size):
    """Load image files as one preprocessed batch for VGG16.

    Each path in ``images_list`` is loaded and resized to
    ``(img_size, img_size)``, the images are stacked along a new batch
    axis, and ``preprocess_input`` is applied to the whole batch.

    Returns a float array of shape (len(images_list), img_size, img_size, 3).
    """
    batch = np.vstack([
        np.expand_dims(image.load_img(path, target_size=(img_size, img_size)), 0)
        for path in images_list
    ])
    return preprocess_input(batch)
class PerceptualModel:
    """Perceptual (feature-matching) loss between reference images and a
    generated-image tensor, using an intermediate VGG16 activation.

    Typical use: build the loss graph once with ``build_perceptual_model``,
    load targets with ``set_reference_images``, then iterate ``optimize``.
    """

    def __init__(self, img_size, layer=9, batch_size=1, sess=None):
        """
        Args:
            img_size: side length images are resized to before VGG16.
            layer: index into ``vgg16.layers`` whose output is compared.
            batch_size: max number of reference images per batch.
            sess: TF session to use; defaults to the current default session.
        """
        self.sess = tf.get_default_session() if sess is None else sess
        # Keras must share the same session so VGG16 weights live in it.
        K.set_session(self.sess)
        self.img_size = img_size
        self.layer = layer
        self.batch_size = batch_size
        self.perceptual_model = None    # Keras sub-model: input -> chosen layer
        self.ref_img_features = None    # TF variable holding target features
        self.features_weight = None     # 0/1 mask for padded batch slots
        self.loss = None                # scalar MSE loss tensor

    def build_perceptual_model(self, generated_image_tensor):
        """Build the feature-matching loss graph over ``generated_image_tensor``.

        The generated image is resized (bilinear, method=1) to the VGG16
        input size, preprocessed, and passed through VGG16 up to
        ``self.layer``; the loss compares those features against the
        ``ref_img_features`` variable, masked by ``features_weight``.
        """
        vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3))
        self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
        generated_image = preprocess_input(
            tf.image.resize_images(generated_image_tensor,
                                   (self.img_size, self.img_size), method=1))
        generated_img_features = self.perceptual_model(generated_image)
        self.ref_img_features = tf.get_variable('ref_img_features',
                                                shape=generated_img_features.shape,
                                                dtype='float32',
                                                initializer=tf.initializers.zeros())
        self.features_weight = tf.get_variable('features_weight',
                                               shape=generated_img_features.shape,
                                               dtype='float32',
                                               initializer=tf.initializers.zeros())
        # BUG FIX: the original ran features_weight.initializer twice and
        # never initialized ref_img_features; initialize both variables.
        self.sess.run([self.features_weight.initializer,
                       self.ref_img_features.initializer])
        # 82890.0 is an empirical normalization constant for the feature MSE.
        self.loss = tf.losses.mean_squared_error(
            self.features_weight * self.ref_img_features,
            self.features_weight * generated_img_features) / 82890.0

    def set_reference_images(self, images_list):
        """Load reference images and store their VGG16 features as targets.

        If fewer than ``batch_size`` images are given, the remaining batch
        slots are zero-padded and masked out via ``features_weight`` so they
        do not contribute to the loss.
        """
        assert (len(images_list) != 0 and len(images_list) <= self.batch_size)
        loaded_image = load_images(images_list, self.img_size)
        image_features = self.perceptual_model.predict_on_batch(loaded_image)

        # In case the number of images is less than the actual batch size,
        # pad features with zeros and zero out their weight mask.
        # (Can be optimized further.)
        weight_mask = np.ones(self.features_weight.shape)
        if len(images_list) != self.batch_size:
            features_space = list(self.features_weight.shape[1:])
            existing_features_shape = [len(images_list)] + features_space
            empty_features_shape = [self.batch_size - len(images_list)] + features_space
            existing_examples = np.ones(shape=existing_features_shape)
            empty_examples = np.zeros(shape=empty_features_shape)
            weight_mask = np.vstack([existing_examples, empty_examples])
            image_features = np.vstack([image_features, np.zeros(empty_features_shape)])
        self.sess.run(tf.assign(self.features_weight, weight_mask))
        self.sess.run(tf.assign(self.ref_img_features, image_features))

    def optimize(self, vars_to_optimize, iterations=500, learning_rate=1.):
        """Run gradient descent on the perceptual loss, yielding the loss
        value after each step.

        Args:
            vars_to_optimize: a TF variable or list of variables to update.
            iterations: number of optimization steps.
            learning_rate: SGD learning rate.

        Yields:
            float loss after each optimization step.
        """
        vars_to_optimize = (vars_to_optimize if isinstance(vars_to_optimize, list)
                            else [vars_to_optimize])
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        # BUG FIX: vars_to_optimize is already a list; the original wrapped it
        # in another list, giving minimize a nested var_list.
        min_op = optimizer.minimize(self.loss, var_list=vars_to_optimize)
        for _ in range(iterations):
            _, loss = self.sess.run([min_op, self.loss])
            yield loss