Skip to content

Commit

Permalink
PhotoWCT CPU support. Pure CPU mode is ~10x slower.
Browse files Browse the repository at this point in the history
  • Loading branch information
suquark authored and mingyuliutw committed Mar 2, 2018
1 parent a167b7b commit 88b1da7
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions photo_wct.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,42 +63,43 @@ def __compute_label_info(self, cont_seg, styl_seg):
for l in self.label_set:
# if l==0:
# continue
is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
if o_cont_mask[0].size <= 10 or o_styl_mask[0].size <= 10 or \
self.__large_dff(o_cont_mask[0].size, o_styl_mask[0].size):
continue
self.label_indicator[l] = 1
self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)

def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
cont_feat_view = cont_feat.view(cont_c, -1).clone()
styl_feat_view = styl_feat.view(styl_c, -1).clone()
target_feature = cont_feat.view(cont_c, -1).clone()

if cont_seg.size == False or styl_seg.size == False:
tmp_target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
target_feature = tmp_target_feature.view_as(cont_feat)
ccsF = target_feature.float().unsqueeze(0)
return ccsF

t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

for l in self.label_set:
if self.label_indicator[l] == 0:
continue
cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
continue
cont_indi = torch.LongTensor(cont_mask[0]).cuda(0)
styl_indi = torch.LongTensor(styl_mask[0]).cuda(0)
cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
tmp_target_feature = self.__wct_core(cFFG, sFFG)
target_feature.index_copy_(1, cont_indi, tmp_target_feature)
target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
else:
target_feature = cont_feat.view(cont_c, -1).clone()

t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

for l in self.label_set:
if self.label_indicator[l] == 0:
continue
cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
continue

cont_indi = torch.LongTensor(cont_mask[0])
styl_indi = torch.LongTensor(styl_mask[0])
if self.is_cuda:
cont_indi = cont_indi.cuda(0)
styl_indi = styl_indi.cuda(0)

cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
tmp_target_feature = self.__wct_core(cFFG, sFFG)
target_feature.index_copy_(1, cont_indi, tmp_target_feature)

target_feature = target_feature.view_as(cont_feat)
ccsF = target_feature.float().unsqueeze(0)
Expand All @@ -110,7 +111,10 @@ def __wct_core(self, cont_feat, styl_feat):
c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
cont_feat = cont_feat - c_mean

iden = torch.eye(cFSize[0]).cuda() # .double()
iden = torch.eye(cFSize[0]) # .double()
if self.is_cuda:
iden = iden.cuda()

contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
# del iden
c_u, c_e, c_v = torch.svd(contentConv, some=False)
Expand Down Expand Up @@ -145,6 +149,6 @@ def __wct_core(self, cont_feat, styl_feat):
targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
return targetFeature

def __large_dff(self, a, b):
return a / b >= 100 or b / a >= 100

@property
def is_cuda(self):
return next(self.parameters()).is_cuda

0 comments on commit 88b1da7

Please sign in to comment.