forked from dreamquark-ai/tabnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
multitask.py
178 lines (156 loc) · 5.77 KB
/
multitask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
import numpy as np
from scipy.special import softmax
from pytorch_tabnet.utils import SparsePredictDataset, PredictDataset, filter_weights
from pytorch_tabnet.abstract_model import TabModel
from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim
from torch.utils.data import DataLoader
import scipy
class TabNetMultiTaskClassifier(TabModel):
    """TabNet model for multi-task classification.

    All tasks share the same TabNet encoder; each task has its own
    classification head. Targets are expected as a 2D array of shape
    (n_samples, n_tasks), one label column per task.
    """

    def __post_init__(self):
        super(TabNetMultiTaskClassifier, self).__post_init__()
        self._task = 'classification'
        self._default_loss = torch.nn.functional.cross_entropy
        self._default_metric = 'logloss'

    def prepare_target(self, y):
        """Map raw labels of every task to their encoded integer indices.

        Parameters
        ----------
        y : np.ndarray of shape (n_samples, n_tasks)
            Raw target labels.

        Returns
        -------
        np.ndarray
            Copy of ``y`` with each column remapped through the matching
            ``self.target_mapper`` dict (built in ``update_fit_params``).
        """
        y_mapped = y.copy()
        for task_idx in range(y.shape[1]):
            task_mapper = self.target_mapper[task_idx]
            y_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y[:, task_idx])
        return y_mapped

    def compute_loss(self, y_pred, y_true):
        """
        Computes the loss according to network output and targets.

        Parameters
        ----------
        y_pred : list of tensors
            Output of network, one tensor of logits per task.
        y_true : LongTensor
            Targets label encoded, shape (batch_size, n_tasks).

        Returns
        -------
        loss : torch.Tensor
            Mean of the per-task losses.
        """
        loss = 0
        y_true = y_true.long()
        if isinstance(self.loss_fn, list):
            # a distinct loss function was supplied for each task
            for task_id, (task_loss, task_output) in enumerate(
                zip(self.loss_fn, y_pred)
            ):
                loss += task_loss(task_output, y_true[:, task_id])
        else:
            # the same loss function is applied to all tasks
            for task_id, task_output in enumerate(y_pred):
                loss += self.loss_fn(task_output, y_true[:, task_id])
        # average so the loss magnitude does not grow with the number of tasks
        loss /= len(y_pred)
        return loss

    def stack_batches(self, list_y_true, list_y_score):
        """Stack per-batch targets and scores into per-task arrays.

        Parameters
        ----------
        list_y_true : list of np.ndarray
            Per-batch encoded targets.
        list_y_score : list of list of np.ndarray
            Per-batch raw logits, one inner entry per task.

        Returns
        -------
        tuple (np.ndarray, list of np.ndarray)
            Stacked targets and per-task softmax probabilities.
        """
        y_true = np.vstack(list_y_true)
        y_score = []
        for i in range(len(self.output_dim)):
            score = np.vstack([x[i] for x in list_y_score])
            # convert stacked logits to probabilities for metric computation
            score = softmax(score, axis=1)
            y_score.append(score)
        return y_true, y_score

    def update_fit_params(self, X_train, y_train, eval_set, weights):
        """Infer per-task output dims and label mappings before training.

        Parameters
        ----------
        X_train : np.ndarray
            Training features (unused here, kept for the TabModel contract).
        y_train : np.ndarray of shape (n_samples, n_tasks)
            Training targets.
        eval_set : list of (X, y) tuples
            Validation sets; each task's labels are checked against the
            classes seen during training.
        weights : int or dict
            Sample-weighting scheme, validated by ``filter_weights``.
        """
        output_dim, train_labels = infer_multitask_output(y_train)
        for _, y in eval_set:
            for task_idx in range(y.shape[1]):
                check_output_dim(train_labels[task_idx], y[:, task_idx])
        self.output_dim = output_dim
        self.classes_ = train_labels
        # label -> encoded index, one dict per task
        self.target_mapper = [
            {class_label: index for index, class_label in enumerate(classes)}
            for classes in self.classes_
        ]
        # encoded index (as str) -> original label (as str), one dict per task
        self.preds_mapper = [
            {str(index): str(class_label) for index, class_label in enumerate(classes)}
            for classes in self.classes_
        ]
        self.updated_weights = weights
        filter_weights(self.updated_weights)

    def _build_predict_dataloader(self, X):
        """Build an inference DataLoader for dense or sparse input.

        Shared by ``predict`` and ``predict_proba`` to avoid duplicating
        the dense/sparse branch.
        """
        if scipy.sparse.issparse(X):
            dataset = SparsePredictDataset(X)
        else:
            dataset = PredictDataset(X)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(self, X):
        """
        Make predictions on a batch (valid).

        Parameters
        ----------
        X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data

        Returns
        -------
        results : list of np.array
            One array per task with the most probable class labels.
        """
        self.network.eval()
        dataloader = self._build_predict_dataloader(X)
        results = {}
        for data in dataloader:
            data = data.to(self.device).float()
            output, _ = self.network(data)
            # argmax over raw logits: softmax is strictly monotonic, so
            # applying it first (as the original did) cannot change the argmax
            predictions = [
                torch.argmax(task_output, dim=1).cpu().detach().numpy().reshape(-1)
                for task_output in output
            ]
            for task_idx in range(len(self.output_dim)):
                # append in place instead of rebuilding the list every batch
                results.setdefault(task_idx, []).append(predictions[task_idx])
        # stack all tasks individually
        results = [np.hstack(task_res) for task_res in results.values()]
        # map encoded predictions back to original class labels per task
        results = [
            np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str))
            for task_idx, task_res in enumerate(results)
        ]
        return results

    def predict_proba(self, X):
        """
        Make predictions for classification on a batch (valid).

        Parameters
        ----------
        X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data

        Returns
        -------
        res : list of np.ndarray
            One (n_samples, n_classes_task) probability array per task.
        """
        self.network.eval()
        dataloader = self._build_predict_dataloader(X)
        results = {}
        for data in dataloader:
            data = data.to(self.device).float()
            output, _ = self.network(data)
            # functional softmax avoids allocating an nn.Softmax module per batch
            predictions = [
                torch.nn.functional.softmax(task_output, dim=1).cpu().detach().numpy()
                for task_output in output
            ]
            for task_idx in range(len(self.output_dim)):
                results.setdefault(task_idx, []).append(predictions[task_idx])
        res = [np.vstack(task_res) for task_res in results.values()]
        return res