-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss_functions.py
113 lines (96 loc) · 4.71 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
import torch.nn.functional as F
def loss_weights(task, dsets):
    '''
    You must set these hyperparameters to apply our method to other datasets.
    These hyperparameters may not be the optimal value for your machine.
    '''
    # Weight table: flat scalars plus per-direction / per-dataset sub-dicts.
    alpha = {'style': {}, 'dis': {}, 'gen': {}}
    has_m, has_mm, has_u = ('M' in dsets), ('MM' in dsets), ('U' in dsets)
    if task == 'clf':
        alpha['recon'], alpha['consis'], alpha['content'] = 5, 1, 1
        if has_m and has_mm and not has_u:
            # MNIST <-> MNIST-M
            alpha['style'].update(M2MM=5e4, MM2M=1e4)
            alpha['dis'].update(M=0.5, MM=0.5)
            alpha['gen'].update(M=0.5, MM=1.0)
        elif has_m and has_u and not has_mm:
            # MNIST <-> USPS
            alpha['style'].update(M2U=5e3, U2M=5e3)
            alpha['dis'].update(M=0.5, U=0.5)
            alpha['gen'].update(M=0.5, U=0.5)
        elif has_m and has_u and has_mm:
            # MNIST <-> MNIST-M <-> USPS
            alpha['style'].update(M2MM=5e4, MM2M=1e4, M2U=1e4, U2M=1e4)
            alpha['dis'].update(M=0.5, MM=0.5, U=0.5)
            alpha['gen'].update(M=0.5, MM=1.0, U=0.5)
    elif task == 'seg':
        # GTA5 <-> Cityscapes
        alpha['recon'], alpha['consis'], alpha['content'] = 10, 1, 1
        alpha['style'].update(G2C=5e3, C2G=5e3)
        alpha['dis'].update(G=0.5, C=0.5)
        alpha['gen'].update(G=0.5, C=0.5)
    return alpha
class Loss_Functions:
    """Bundles every loss term used to train the translation model.

    Inputs to the public methods are dicts keyed either by dataset name
    (e.g. 'M') or by a convert string 'source2target' (e.g. 'M2MM');
    each method returns a scalar tensor already multiplied by the
    corresponding weight from ``loss_weights``.
    """

    def __init__(self, args):
        # args must expose .task ('clf' or 'seg') and .datasets.
        self.args = args
        self.alpha = loss_weights(args.task, args.datasets)

    def recon(self, imgs, recon_imgs):
        """Weighted L1 reconstruction loss, summed over all datasets."""
        recon_loss = 0
        for dset in imgs:
            recon_loss += F.l1_loss(imgs[dset], recon_imgs[dset])
        return self.alpha['recon'] * recon_loss

    def dis(self, real, fake):
        """Discriminator loss: DCGAN-style BCE for 'clf', hinge for 'seg'.

        ``real`` is keyed by dataset, ``fake`` by convert string; the fake
        term is weighted by the *target* domain's discriminator weight.
        """
        dis_loss = 0
        if self.args.task == 'clf':  # DCGAN loss
            for dset in real:
                dis_loss += self.alpha['dis'][dset] * F.binary_cross_entropy(
                    real[dset], torch.ones_like(real[dset]))
            for cv in fake:
                target = cv.split('2')[1]
                dis_loss += self.alpha['dis'][target] * F.binary_cross_entropy(
                    fake[cv], torch.zeros_like(fake[cv]))
        elif self.args.task == 'seg':  # Hinge loss
            for dset in real:
                dis_loss += self.alpha['dis'][dset] * F.relu(1. - real[dset]).mean()
            for cv in fake:
                target = cv.split('2')[1]
                dis_loss += self.alpha['dis'][target] * F.relu(1. + fake[cv]).mean()
        return dis_loss

    def gen(self, fake):
        """Generator adversarial loss over all converted outputs.

        BCE-against-ones for 'clf'; negated discriminator score for 'seg'
        (the hinge-loss generator objective).
        """
        gen_loss = 0
        for cv in fake:
            target = cv.split('2')[1]
            if self.args.task == 'clf':
                gen_loss += self.alpha['gen'][target] * F.binary_cross_entropy(
                    fake[cv], torch.ones_like(fake[cv]))
            elif self.args.task == 'seg':
                gen_loss += -self.alpha['gen'][target] * fake[cv].mean()
        return gen_loss

    def content_perceptual(self, perceptual, perceptual_converted):
        """MSE between the deepest ([-1]) perceptual features of the source
        image and its converted counterpart."""
        content_perceptual_loss = 0
        for cv in perceptual_converted:
            source = cv.split('2')[0]
            content_perceptual_loss += F.mse_loss(
                perceptual[source][-1], perceptual_converted[cv][-1])
        return self.alpha['content'] * content_perceptual_loss

    def style_perceptual(self, style_gram, style_gram_converted):
        """Gram-matrix style loss, weighted per convert direction.

        Compares every Gram matrix of the target-domain style to that of
        the converted image.
        """
        style_perceptual_loss = 0  # fixed misspelled local name
        for cv in style_gram_converted:
            target = cv.split('2')[1]
            for gr in range(len(style_gram[target])):
                style_perceptual_loss += self.alpha['style'][cv] * F.mse_loss(
                    style_gram[target][gr], style_gram_converted[cv][gr])
        return style_perceptual_loss

    def consistency(self, contents, styles, contents_converted, styles_converted, converts):
        """L1 consistency between original and re-extracted content/style codes."""
        consistency_loss = 0
        for cv in converts:
            target = cv.split('2')[1]
            consistency_loss += F.l1_loss(contents[cv], contents_converted[cv])
            consistency_loss += F.l1_loss(styles[target], styles_converted[cv])
        return self.alpha['consis'] * consistency_loss

    def task(self, pred, gt):
        """Task (classification/segmentation) cross-entropy loss.

        Converted predictions (keys containing '2') are scored against the
        ground truth of their *source* domain; ignore_index=-1 skips
        unlabeled pixels.
        """
        task_loss = 0
        for key in pred:
            source = key.split('2')[0] if '2' in key else key
            task_loss += F.cross_entropy(pred[key], gt[source], ignore_index=-1)
        return task_loss