From 88b1da7b081b0e7b19dab9063fafd0e4bdc21801 Mon Sep 17 00:00:00 2001 From: suquark Date: Sat, 24 Feb 2018 20:35:34 +0800 Subject: [PATCH] PhotoWCT CPU support. Pure CPU mode is ~10x slower. --- photo_wct.py | 64 ++++++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/photo_wct.py b/photo_wct.py index 0c038a9..2fb6fda 100644 --- a/photo_wct.py +++ b/photo_wct.py @@ -63,42 +63,43 @@ def __compute_label_info(self, cont_seg, styl_seg): for l in self.label_set: # if l==0: # continue + is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100 o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l) o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l) - if o_cont_mask[0].size <= 10 or o_styl_mask[0].size <= 10 or \ - self.__large_dff(o_cont_mask[0].size, o_styl_mask[0].size): - continue - self.label_indicator[l] = 1 + self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size) def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg): cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2) styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2) cont_feat_view = cont_feat.view(cont_c, -1).clone() styl_feat_view = styl_feat.view(styl_c, -1).clone() - target_feature = cont_feat.view(cont_c, -1).clone() if cont_seg.size == False or styl_seg.size == False: - tmp_target_feature = self.__wct_core(cont_feat_view, styl_feat_view) - target_feature = tmp_target_feature.view_as(cont_feat) - ccsF = target_feature.float().unsqueeze(0) - return ccsF - - t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST)) - t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST)) - - for l in self.label_set: - if self.label_indicator[l] == 0: - continue - cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l) - styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l) - if cont_mask[0].size <= 0 or styl_mask[0].size <= 0: - continue - cont_indi = torch.LongTensor(cont_mask[0]).cuda(0) - styl_indi = torch.LongTensor(styl_mask[0]).cuda(0) - cFFG = torch.index_select(cont_feat_view, 1, cont_indi) - sFFG = torch.index_select(styl_feat_view, 1, styl_indi) - tmp_target_feature = self.__wct_core(cFFG, sFFG) - target_feature.index_copy_(1, cont_indi, tmp_target_feature) + target_feature = self.__wct_core(cont_feat_view, styl_feat_view) + else: + target_feature = cont_feat.view(cont_c, -1).clone() + + t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST)) + t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST)) + + for l in self.label_set: + if self.label_indicator[l] == 0: + continue + cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l) + styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l) + if cont_mask[0].size <= 0 or styl_mask[0].size <= 0: + continue + + cont_indi = torch.LongTensor(cont_mask[0]) + styl_indi = torch.LongTensor(styl_mask[0]) + if self.is_cuda: + cont_indi = cont_indi.cuda(0) + styl_indi = styl_indi.cuda(0) + + cFFG = torch.index_select(cont_feat_view, 1, cont_indi) + sFFG = torch.index_select(styl_feat_view, 1, styl_indi) + tmp_target_feature = self.__wct_core(cFFG, sFFG) + target_feature.index_copy_(1, cont_indi, tmp_target_feature) target_feature = target_feature.view_as(cont_feat) ccsF = target_feature.float().unsqueeze(0) @@ -110,7 +111,10 @@ def __wct_core(self, cont_feat, styl_feat): c_mean = c_mean.unsqueeze(1).expand_as(cont_feat) cont_feat = cont_feat - c_mean - iden = torch.eye(cFSize[0]).cuda() # .double() + iden = torch.eye(cFSize[0]) # .double() + if self.is_cuda: + iden = iden.cuda() + contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden # del iden c_u, c_e, c_v = torch.svd(contentConv, some=False) @@ -145,6 +149,6 @@ def __wct_core(self, cont_feat, styl_feat): targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature) return targetFeature - def __large_dff(self, a, b): - return a / b >= 100 or b / a >= 100 - + @property + def is_cuda(self): + return next(self.parameters()).is_cuda