main.py (forked from NIPS-BRITS/BRITS)
import argparse

import numpy as np
import torch
import torch.optim as optim
from sklearn import metrics

import data_loader
import models
import utils
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--model', type=str)
args = parser.parse_args()
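# Note: `--model` is resolved via getattr(models, args.model) in run() below, so
# it must name a module in the models package that defines a `Model` class
# (e.g. `brits` in the upstream BRITS repo; the exact module names available
# here are an assumption).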

def train(model):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    data_iter = data_loader.get_loader(batch_size=args.batch_size)

    for epoch in range(args.epochs):
        model.train()
        run_loss = 0.0

        for idx, data in enumerate(data_iter):
            data = utils.to_var(data)
            ret = model.run_on_batch(data, optimizer)

            # .item() replaces the deprecated Python 2-era .data[0] for
            # extracting a scalar loss
            run_loss += ret['loss'].item()

            print('\r Progress epoch {}, {:.2f}%, average loss {}'.format(
                epoch, (idx + 1) * 100.0 / len(data_iter), run_loss / (idx + 1.0)), end='')

        # evaluate after every epoch (the original `epoch % 1 == 0` check was
        # always true)
        evaluate(model, data_iter)

def evaluate(model, val_iter):
    model.eval()

    labels = []
    preds = []

    evals = []
    imputations = []

    for idx, data in enumerate(val_iter):
        data = utils.to_var(data)
        ret = model.run_on_batch(data, None)

        pred = ret['predictions'].data.cpu().numpy()
        label = ret['labels'].data.cpu().numpy()
        is_train = ret['is_train'].data.cpu().numpy()

        eval_masks = ret['eval_masks'].data.cpu().numpy()
        eval_ = ret['evals'].data.cpu().numpy()
        imputation = ret['imputations'].data.cpu().numpy()

        # collect ground-truth values and imputations at the held-out positions
        evals += eval_[np.where(eval_masks == 1)].tolist()
        imputations += imputation[np.where(eval_masks == 1)].tolist()

        # collect test labels & predictions (is_train == 0 marks held-out samples)
        pred = pred[np.where(is_train == 0)]
        label = label[np.where(is_train == 0)]

        labels += label.tolist()
        preds += pred.tolist()

    labels = np.asarray(labels).astype('int32')
    preds = np.asarray(preds)

    print('\nAUC {}'.format(metrics.roc_auc_score(labels, preds)))

    evals = np.asarray(evals)
    imputations = np.asarray(imputations)

    # mean absolute error and mean relative error of the imputed values
    print('MAE', np.abs(evals - imputations).mean())
    print('MRE', np.abs(evals - imputations).sum() / np.abs(evals).sum())

def run():
    model = getattr(models, args.model).Model()

    if torch.cuda.is_available():
        model = model.cuda()

    train(model)


if __name__ == '__main__':
    run()
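
# Example invocation (a sketch; assumes a `brits` model module exists, as in
# the upstream BRITS repo):
#   python main.py --model brits --epochs 1000 --batch_size 32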