-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
nyxflower
committed
Aug 8, 2019
1 parent
13b9837
commit ce18721
Showing
3 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from data.utils import load_data_torch | ||
import numpy as np | ||
import pickle | ||
import csv | ||
|
||
# drug id - index
# Load the drug-id -> matrix-index map produced during preprocessing, plus
# its inverse for reverse lookups (matrix index -> original drug id).
with open('../data/index_map/drug-map.pkl', 'rb') as f:
    drug_map = pickle.load(f)
inv_drug_map = {v: k for k, v in drug_map.items()}

# combo id - index
# Same for drug-combination (polypharmacy side effect) ids.
with open('../data/index_map/combo_map.pkl', 'rb') as f:
    combo_map = pickle.load(f)
inv_combo_map = {v: k for k, v in combo_map.items()}

# selected-drug idx - drug idx
# et_list holds the drug-drug edge types (side effects) kept after
# filtering; inv_et_list maps an edge-type id back to its position.
with open('../out/decagon_et.pkl', 'rb') as f:       # the whole dataset
    et_list = pickle.load(f)
inv_et_list = {et_list[i]: i for i in range(len(et_list))}

# Graph data (adjacency lists etc.) for the selected edge types; mono=False
# skips the mono (single-drug) side-effect features.
feed_dict = load_data_torch("../data/", et_list, mono=False)
|
||
# ######################################################
# generate polypharmacy side effect id - name map
# One-off preprocessing kept for reference: it parsed the raw Decagon combo
# CSV and pickled an {int side-effect id: human-readable name} map, taking
# the numeric part of the CID code (e.g. 'C0026780' -> 26780).
# combo_name_map = {}
# with open('../data/index_map/bio-decagon-combo.csv', 'r') as f:
#     reader = csv.reader(f)
#     next(reader)
#     for _, _, id, name in reader:
#         id = int(id.split('C')[-1])
#         combo_name_map[id] = name
#
# # save map
# with open('../data/index_map/combo-name-map.pkl', 'wb') as f:
#     pickle.dump(combo_name_map, f)

# use map
# Load the previously generated side-effect id -> name map.
with open('../data/index_map/combo-name-map.pkl', 'rb') as f:
    combo_name_map = pickle.load(f)
|
||
# ######################################################
# side effect name - original index reported in decagon
# The ten best- and ten worst-predicted side effects reported by the
# Decagon paper, together with their original side-effect ids.
decagon_best_name = ["Mumps", "Carbuncle", "Coccydynia", "Tympanic membrane perfor", "Dyshidrosis", "Spondylosis", "Schizoaffective disorder", "Breast dysplasia", "Ganglion", "Uterine polyp"]
decagon_worst_name = ["Bleeding", "Body temperature increased", "Emesis", "Renal disorder", "Leucopenia", "Diarrhea", "Icterus", "Nausea", "Itch", "Anaemia"]
decagon_best_org_id = [26780, 7078, 9193, 206504, 32633, 38019, 36337, 16034, 1258666, 156369]
decagon_worst_org_id = [19080, 15967, 42963, 22658, 23530, 11991, 22346, 27497, 33774, 2871]


def _org_ids_to_et_indices(org_ids):
    # Map original Decagon side-effect ids to positions in et_list.
    return [inv_et_list[combo_map[org_id]] for org_id in org_ids]


# get index
decagon_best_idx = _org_ids_to_et_indices(decagon_best_org_id)
decagon_worst_idx = _org_ids_to_et_indices(decagon_worst_org_id)
|
||
# ######################################################
# Evaluation
# Rank every drug-drug edge type (side effect) by the AUPRC it obtained in
# the last recorded test epoch, then print the 20 best and 20 worst.
name = 'RGCN-DistMult on d-net'
with open('../out/dd-rgcn-dist/test_record.pkl', 'rb') as f:
    dist_record = pickle.load(f)

# dist_record maps epoch -> metric arrays; row 0 of the final epoch holds
# the per-edge-type AUPRC.  NOTE(review): layout assumed from indexing here —
# confirm against the training script that wrote test_record.pkl.
auprc = np.array(dist_record[len(dist_record)-1])[0, :]
sorted_idx = np.argsort(auprc, kind='quicksort')    # ascending AUPRC

print(' {:37s} {:6s} | {:45s} {:6s}'.format('The Highest AUPRC Score', ' Edge', 'The Lowest AUPRC Score', ' Edge'))
for i in range(20):
    hi = sorted_idx[-(i + 1)]   # i-th best edge type
    lo = sorted_idx[i]          # i-th worst edge type
    # BUG FIX: the edge counts were previously read at the raw loop index
    # (dd_adj_list[-(i+1)] / dd_adj_list[i]) instead of the ranked index,
    # so they did not belong to the side effects being printed.
    print(' {:30s} {:7.4f} {:6d}| {:38s} {:7.4f} {:6d}'.format(
        combo_name_map[inv_combo_map[et_list[hi]]], auprc[hi], feed_dict['dd_adj_list'][hi].nnz,
        combo_name_map[inv_combo_map[et_list[lo]]], auprc[lo], feed_dict['dd_adj_list'][lo].nnz))

# Rank (from the top) of Decagon's reported best/worst side effects in our
# ranking.  len(sorted_idx) - 1 replaces the hard-coded 962 so the code no
# longer assumes exactly 963 edge types.
decag_best_in_us = [len(sorted_idx) - 1 - np.where(sorted_idx == i)[0] for i in decagon_best_idx]
decag_worst_in_us = [np.where(sorted_idx == i)[0] for i in decagon_worst_idx]
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
from data.utils import load_data_torch | ||
from src.layers import negative_sampling, auprc_auroc_ap | ||
import pickle | ||
from torch.nn import Module | ||
import torch | ||
from src.layers import * | ||
from model.pd_net import MyHierarchyConv | ||
import sys | ||
import time | ||
|
||
# Deep recursion happens inside pickling/graph utilities; raise the limit.
sys.setrecursionlimit(8000)

# Drug-drug edge types (side effects) kept after filtering.
with open('../out/decagon_et.pkl', 'rb') as f:       # the whole dataset
    et_list = pickle.load(f)

# et_list = et_list
# Build graph tensors (drug/protein features, adjacency) and wrap them in a
# Data object; mono=True includes single-drug side-effect features.
feed_dict = load_data_torch("../data/", et_list, mono=True)
data = Data.from_dict(feed_dict)
n_drug, n_drug_feat = data.d_feat.shape
n_prot, n_prot_feat = data.p_feat.shape
n_et_dd = len(et_list)          # number of drug-drug edge types

# Split drug-drug edges into train/test: edge indices, edge types, and the
# per-edge-type range boundaries.
data.train_idx, data.train_et, data.train_range,data.test_idx, data.test_et, data.test_range = process_edges(data.dd_edge_index)
|
||
# re-construct node feature
# Replace raw features with identity (one-hot) features: the combined
# protein+drug feature matrix gets an n_prot identity block for proteins and
# zeros for drugs; drugs get their own n_drug identity matrix.
data.p_feat = torch.cat([dense_id(n_prot), torch.zeros(size=(n_drug, n_prot))], dim=0)
data.d_feat = dense_id(n_drug)
n_drug_feat = n_drug    # feature dims now equal the node counts
n_prot_feat = n_prot
|
||
# ###################################
# dp_edge_index and range index
# ###################################
# Drug-protein edges as a (2, E) array: row 0 = protein indices, row 1 =
# drug indices.  The -1 compensates for 1-based ids in the stored adjacency
# — NOTE(review): assumed from the offsets here; confirm against the
# preprocessing that wrote dp_adj.
data.dp_edge_index = np.array([data.dp_adj.col-1, data.dp_adj.row-1])


def _build_range_list(counts):
    """Turn per-drug edge counts into cumulative (start, end) slices.

    counts[i] is the number of edges targeting drug i (edges are assumed
    grouped by drug); the result gives each drug's half-open [start, end)
    slice into the edge array.
    """
    range_list = []
    start = 0
    for c in counts:
        end = start + int(c)
        range_list.append((start, end))
        start = end
    return range_list


# FIX: the original used np.zeros(..., dtype=np.int); the np.int alias was
# deprecated in NumPy 1.20 and removed in 1.24, so this crashed on modern
# NumPy.  np.bincount does the same per-drug counting in one C-level pass.
count_drug = np.bincount(data.dp_edge_index[1, :], minlength=n_drug)
range_list = _build_range_list(count_drug)

# Shift the drug row by n_prot so drugs occupy indices
# [n_prot, n_prot + n_drug) in the combined protein+drug node space.
data.dp_edge_index = torch.from_numpy(data.dp_edge_index + np.array([[0], [n_prot]]))
data.dp_range_list = range_list
|
||
|
||
# Per-node normalisation constants for the encoder; all ones means no
# normalisation is applied (p_norm covers the combined protein+drug space).
data.d_norm = torch.ones(n_drug)
data.p_norm = torch.ones(n_prot+n_drug)
# data.x_norm = torch.sqrt(data.d_feat.sum(dim=1))
# data.d_feat.requires_grad = True


# Encoder layer sizes: protein one-hot -> 32-d embedding -> 16-d drug code.
source_dim = n_prot_feat
embed_dim = 32
target_dim = 16
||
class HierEncoder(Module):
    """Embed source-node features and pool them into target-node codes.

    A dense linear embedding (source_dim -> embed_dim) is followed by a
    hierarchical graph convolution that aggregates the uni_num_source
    source nodes into uni_num_target target nodes of width target_dim.
    """

    def __init__(self, source_dim, embed_dim, target_dim,
                 uni_num_source, uni_num_target):
        super(HierEncoder, self).__init__()

        # Embedding matrix applied to the raw source features.
        self.embed = Param(torch.Tensor(source_dim, embed_dim))
        # Hierarchical conv mapping embedded source nodes to target nodes.
        self.hgcn = MyHierarchyConv(embed_dim, target_dim,
                                    uni_num_source, uni_num_target)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the embedding weights from a standard normal."""
        self.embed.data.normal_()

    def forward(self, source_feat, edge_index, range_list, x_norm):
        """Return target-node representations for the given graph."""
        embedded = source_feat.matmul(self.embed)
        normalised = embedded / x_norm.view(-1, 1)
        # x = F.relu(x, inplace=True)   # optional non-linearity, disabled
        return self.hgcn(normalised, edge_index, range_list)
|
||
|
||
class NNDecoder(Module):
    """Two-layer neural decoder scoring drug-drug edges.

    Each edge endpoint is passed through its own hidden layer
    (in_dim -> l1_dim, ReLU), reduced to a scalar with an
    edge-type-specific weight row, and the two scalars are summed and
    squashed with a sigmoid into an interaction probability.

    in_dim: feature dimension of one drug embedding.
    num_uni_edge_type: number of drug-drug edge types.
    l1_dim: hidden layer width.
    """

    def __init__(self, in_dim, num_uni_edge_type, l1_dim=8):
        super(NNDecoder, self).__init__()
        self.l1_dim = l1_dim    # decoder hidden-layer width

        # Parameters for endpoint 1: shared first layer, then one weight
        # row per drug-drug edge type.
        self.w1_l1 = Param(torch.Tensor(in_dim, l1_dim))
        self.w1_l2 = Param(torch.Tensor(num_uni_edge_type, l1_dim))
        # Parameters for endpoint 2, same shapes.
        self.w2_l1 = Param(torch.Tensor(in_dim, l1_dim))
        self.w2_l2 = Param(torch.Tensor(num_uni_edge_type, l1_dim))

        self.reset_parameters()

    def forward(self, z, edge_index, edge_type):
        """Score each edge in edge_index under its edge_type."""
        # Layer 1: endpoint-specific hidden representations.
        hidden1 = F.relu(z[edge_index[0]].matmul(self.w1_l1), inplace=True)
        hidden2 = F.relu(z[edge_index[1]].matmul(self.w2_l1), inplace=True)

        # Layer 2: dot each hidden vector with its edge type's weight row.
        score1 = (hidden1 * self.w1_l2[edge_type]).sum(dim=1)
        score2 = (hidden2 * self.w2_l2[edge_type]).sum(dim=1)

        return torch.sigmoid(score1 + score2)

    def reset_parameters(self):
        """Standard-normal init; layer 2 scaled by 1/sqrt(l1_dim)."""
        self.w1_l1.data.normal_()
        self.w2_l1.data.normal_()
        self.w1_l2.data.normal_(std=1 / np.sqrt(self.l1_dim))
        self.w2_l2.data.normal_(std=1 / np.sqrt(self.l1_dim))
|
||
|
||
# Assemble the graph auto-encoder: hierarchical protein->drug encoder plus
# the per-edge-type neural decoder.
encoder = HierEncoder(source_dim, embed_dim, target_dim, n_prot, n_drug)
decoder = NNDecoder(target_dim, n_et_dd)
model = MyGAE(encoder, decoder)

# Prefer GPU when available.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
# device_name = 'cpu'
print(device_name)
device = torch.device(device_name)

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
data = data.to(device)

# Per-epoch metric records: epoch -> [auprc, auroc, ap].
train_out = {}
test_out = {}
|
||
def train():
    """Run one optimisation step over the full training edge set.

    Uses the module-level model / optimizer / data / epoch; records
    [auprc, auroc, ap] in train_out[epoch] and returns (z, loss) where z
    holds the drug embeddings used for scoring.
    """
    model.train()
    optimizer.zero_grad()

    z = model.encoder(data.p_feat, data.dp_edge_index, data.dp_range_list, data.p_norm)

    # Score the observed edges plus one freshly sampled negative per
    # positive, reusing the positives' edge types for the negatives.
    pos_index = data.train_idx
    neg_index = negative_sampling(data.train_idx, n_drug).to(device)

    pos_score = model.decoder(z, pos_index, data.train_et)
    neg_score = model.decoder(z, neg_index, data.train_et)

    # Binary cross-entropy written out by hand; EPS guards log(0).
    loss = -torch.log(pos_score + EPS).mean() - torch.log(1 - neg_score + EPS).mean()

    loss.backward()
    optimizer.step()

    # Training metrics on this epoch's positive and negative samples.
    score = torch.cat([pos_score, neg_score])
    target = torch.cat([torch.ones(pos_score.shape[0]),
                        torch.zeros(neg_score.shape[0])])
    auprc, auroc, ap = auprc_auroc_ap(target, score)

    print(epoch, ' ',
          'loss:', loss.tolist(), ' ',
          'auprc:', auprc, ' ',
          'auroc:', auroc, ' ',
          'ap:', ap)

    train_out[epoch] = [auprc, auroc, ap]

    return z, loss
|
||
|
||
# Negative test edges are drawn once, so every epoch's test evaluation uses
# the same fixed negatives (comparable metrics across epochs).
test_neg_index = negative_sampling(data.test_idx, n_drug).to(device)
|
||
|
||
def test(z):
    """Evaluate embeddings z on the held-out test edges.

    z: drug embeddings produced by the encoder (one row per drug).
    Returns (auprc, auroc, ap) over the true test edges plus the fixed,
    pre-drawn negatives in test_neg_index.
    """
    model.eval()

    # Evaluation only: disable autograd so the decoder forward passes do
    # not build an (unused) computation graph every epoch.
    with torch.no_grad():
        pos_score = model.decoder(z, data.test_idx, data.test_et)
        neg_score = model.decoder(z, test_neg_index, data.test_et)

    pos_target = torch.ones(pos_score.shape[0])
    neg_target = torch.zeros(neg_score.shape[0])

    score = torch.cat([pos_score, neg_score])
    target = torch.cat([pos_target, neg_target])

    auprc, auroc, ap = auprc_auroc_ap(target, score)

    return auprc, auroc, ap
|
||
|
||
EPOCH_NUM = 100
out_dir = '../out/pd-32-16-8-16-963/'   # output dir for saved records

print('model training ...')
for epoch in range(EPOCH_NUM):
    time_begin = time.time()

    # One optimisation step, then evaluate the resulting embeddings on the
    # held-out test edges.
    z, loss = train()

    auprc, auroc, ap = test(z)

    # Per-epoch summary (loss is the training loss; metrics are test-set).
    print(epoch, ' ',
          'loss:', loss.tolist(), ' ',
          'auprc:', auprc, ' ',
          'auroc:', auroc, ' ',
          'ap:', ap, ' ',
          'time:', time.time() - time_begin, '\n')

    # print(epoch, ' ',
    #       'auprc:', auprc)

    test_out[epoch] = [auprc, auroc, ap]
|
||
|
||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.