HAN
taishan1994 committed Sep 19, 2020
1 parent 983a69b commit c2c9d9a
Showing 5 changed files with 104 additions and 10 deletions.
72 changes: 71 additions & 1 deletion README.md
@@ -1,2 +1,72 @@
# pytorch_HAN
Heterogeneous Graph Attention Network (HAN) with pytorch
Paper link:<br>
<a href="https://github.com/Jhy1993/Representation-Learning-on-Heterogeneous-Graph">https://github.com/Jhy1993/Representation-Learning-on-Heterogeneous-Graph</a><br>
A PyTorch implementation of the Heterogeneous Graph Attention Network (HAN). If you want to reproduce the performance reported in the original paper,
this repository may not be suitable for you, because one problem remains unsolved: the training loss decreases while the validation loss increases.<br>

If you just want to understand the basic principles of HAN and how to port TensorFlow code to PyTorch, then this repository is for you.
I implemented it following the structure of the original TensorFlow code.<br>

If you want higher performance, please refer to:<br>
Official TensorFlow implementation: <a href="https://github.com/Jhy1993/HAN">https://github.com/Jhy1993/HAN</a><br>
DGL implementation: <a href="https://github.com/dmlc/dgl/tree/master/examples/pytorch/han">https://github.com/dmlc/dgl/tree/master/examples/pytorch/han</a><br>

# Results
ACM dataset:
The preprocessed ACM dataset can be downloaded from: <a href="https://pan.baidu.com/s/1V2iOikRqHPtVvaANdkzROw">https://pan.baidu.com/s/1V2iOikRqHPtVvaANdkzROw</a> (extraction code: 50k2)<br>

Run training with the following command:
```shell
python main.py
```

Training output:
```text
600 300 2125
y_train:(3025, 3), y_val:(3025, 3), y_test:(3025, 3), train_idx:(1, 600), val_idx:(1, 300), test_idx:(1, 2125)
2
model: pre_trained/acm/acm_allMP_multi_fea_.ckpt
fea_list[0].shape torch.Size([1, 1870, 3025])
biases_list[0].shape: torch.Size([1, 3025, 3025])
3
2
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
Number of training nodes: 600
Number of validation nodes: 300
Number of test nodes: 2125
epoch:001, loss:1.1004, TrainAcc:0.3517, ValLoss:1.1022, ValAcc:0.4000
epoch:002, loss:1.0762, TrainAcc:0.4250, ValLoss:1.1980, ValAcc:0.0533
epoch:003, loss:1.0007, TrainAcc:0.6300, ValLoss:1.4572, ValAcc:0.0533
epoch:004, loss:0.8876, TrainAcc:0.6583, ValLoss:2.0040, ValAcc:0.0500
epoch:005, loss:0.8145, TrainAcc:0.6350, ValLoss:2.7091, ValAcc:0.0500
epoch:006, loss:0.7897, TrainAcc:0.6267, ValLoss:3.2186, ValAcc:0.0500
epoch:007, loss:0.7804, TrainAcc:0.6150, ValLoss:3.4550, ValAcc:0.0500
epoch:008, loss:0.7527, TrainAcc:0.6150, ValLoss:3.5096, ValAcc:0.0500
epoch:009, loss:0.7404, TrainAcc:0.6117, ValLoss:3.5125, ValAcc:0.0600
epoch:010, loss:0.7329, TrainAcc:0.6633, ValLoss:3.5349, ValAcc:0.0400
epoch:011, loss:0.7169, TrainAcc:0.6983, ValLoss:3.5743, ValAcc:0.0133
epoch:012, loss:0.6934, TrainAcc:0.6917, ValLoss:3.6612, ValAcc:0.0033
epoch:013, loss:0.6711, TrainAcc:0.6750, ValLoss:3.7738, ValAcc:0.0033
epoch:014, loss:0.6645, TrainAcc:0.6733, ValLoss:3.9418, ValAcc:0.0200
epoch:015, loss:0.6652, TrainAcc:0.6833, ValLoss:4.0934, ValAcc:0.0300
epoch:016, loss:0.6515, TrainAcc:0.6883, ValLoss:4.2498, ValAcc:0.0300
epoch:017, loss:0.6238, TrainAcc:0.7050, ValLoss:4.4304, ValAcc:0.0300
epoch:018, loss:0.6082, TrainAcc:0.7317, ValLoss:4.5820, ValAcc:0.0333
epoch:019, loss:0.6030, TrainAcc:0.7517, ValLoss:4.7110, ValAcc:0.0367
epoch:020, loss:0.5933, TrainAcc:0.7850, ValLoss:4.8053, ValAcc:0.0400
epoch:021, loss:0.5824, TrainAcc:0.8267, ValLoss:4.8781, ValAcc:0.0333
epoch:022, loss:0.5655, TrainAcc:0.8017, ValLoss:4.9006, ValAcc:0.0267
epoch:023, loss:0.5333, TrainAcc:0.8083, ValLoss:4.9148, ValAcc:0.0167
epoch:024, loss:0.5175, TrainAcc:0.8050, ValLoss:4.8788, ValAcc:0.0100
epoch:025, loss:0.4994, TrainAcc:0.8117, ValLoss:4.7670, ValAcc:0.0033
epoch:026, loss:0.4888, TrainAcc:0.8333, ValLoss:4.5965, ValAcc:0.0033
```
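This is where the problem lies: the training loss keeps falling and training accuracy improves, while the validation loss climbs and validation accuracy drops to nearly zero.<br>
If you know how to solve this problem, please don't hesitate to tell me.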






9 changes: 2 additions & 7 deletions layer.py
@@ -18,11 +18,6 @@ def __init__(self,in_channel, out_sz, bias_mat, in_drop=0.0, coef_drop=0.0, acti
         self.coef_dropout = nn.Dropout(coef_drop)
         self.activation = activation

-        """
-        for m in self.modules():
-            if isinstance(m, nn.Conv1d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out")
-        """

     def forward(self,x):
         seq = x.float()
@@ -67,9 +62,9 @@ def __init__(self, inputs, attention_size, time_major=False, return_alphas=False
         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.kaiming_uniform_(self.w_omega)
+        nn.init.xavier_uniform_(self.w_omega)
         nn.init.zeros_(self.b_omega)
-        nn.init.kaiming_uniform_(self.u_omega)
+        nn.init.xavier_uniform_(self.u_omega)

     def forward(self,x):
         #print("x.shape:",x.shape)
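
For readers comparing the two initializers swapped in `reset_parameters`, the sketch below computes the sampling bounds that `nn.init.kaiming_uniform_` (with its default arguments) and `nn.init.xavier_uniform_` use for a 2-D weight. It is plain Python rather than a call into PyTorch, and the shapes given for `w_omega` and `u_omega` are hypothetical, since this hunk does not show how they are declared.

```python
import math


def fans(shape):
    """Fan-in and fan-out of a 2-D weight, using PyTorch's convention
    (fan_in = shape[1], fan_out = shape[0])."""
    rows, cols = shape
    return cols, rows


def xavier_uniform_bound(shape, gain=1.0):
    """Values are drawn from U(-bound, bound), bound = gain * sqrt(6 / (fan_in + fan_out))."""
    fan_in, fan_out = fans(shape)
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def kaiming_uniform_bound(shape):
    """Bound for the default arguments (a=0, mode="fan_in", leaky_relu gain sqrt(2)):
    bound = sqrt(2) * sqrt(3 / fan_in)."""
    fan_in, _ = fans(shape)
    return math.sqrt(2.0) * math.sqrt(3.0 / fan_in)


# Hypothetical sizes and shapes, chosen only for illustration.
hidden_size, attention_size = 64, 128
assumed_shapes = {
    "w_omega": (hidden_size, attention_size),
    "u_omega": (attention_size, 1),
}

for name, shape in assumed_shapes.items():
    print(name, shape,
          "kaiming bound: %.4f" % kaiming_uniform_bound(shape),
          "xavier bound: %.4f" % xavier_uniform_bound(shape))
```

Under these assumed shapes, a column-shaped `u_omega` has a fan-in of 1, so the Kaiming bound is far larger than the Xavier one. Whether that difference is related to the validation problem described in the README is not established by this commit.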
28 changes: 28 additions & 0 deletions main.py
@@ -2,7 +2,35 @@
import torch
import torch.nn as nn
import torch.optim as optim
from model import *
from utils import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = "cpu"

adj_list, fea_list, y_train, y_val, y_test, train_mask, val_mask, test_mask, my_data = load_data_dblp()
nb_nodes = fea_list[0].shape[0] # number of nodes: 3025
ft_size = fea_list[0].shape[1] # feature dimension: 1870
nb_classes = y_train.shape[1] # number of classes: 3

fea_list = [torch.transpose(torch.from_numpy(fea[np.newaxis]),2,1).to(device) for fea in fea_list]
#fea_list = torch.from_numpy(np.array(fea_list)).to(device)
adj_list = [adj[np.newaxis] for adj in adj_list]
y_train = y_train[np.newaxis]
y_val = y_val[np.newaxis]
y_test = y_test[np.newaxis]
#train_mask = train_mask[np.newaxis]
#val_mask = val_mask[np.newaxis]
#test_mask = test_mask[np.newaxis]

my_labels = my_data['my_labels']
train_my_labels = my_data['train_my_labels']
val_my_labels = my_data['val_my_labels']
test_my_labels = my_data['test_my_labels']


biases_list = [torch.transpose(torch.from_numpy(adj_to_bias(adj, [nb_nodes], nhood=1)),2,1).to(device) for adj in adj_list]
print(len(biases_list))

dataset = 'acm'
featype = 'fea'
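
The feature preparation in this hunk adds a batch axis with `fea[np.newaxis]` and then swaps the last two axes with `torch.transpose(..., 2, 1)`, which matches the `torch.Size([1, 1870, 3025])` line in the README log. The sketch below traces the same shape changes on nested lists. The helper names and the tiny sizes are hypothetical, and the stated motivation (that `nn.Conv1d` expects input shaped `(batch, channels, length)`) is an inference from `layer.py` referring to `nn.Conv1d`, not something the commit says.

```python
def shape(x):
    """Shape of a rectangular nested list, as a tuple."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def add_batch_axis(matrix):
    """Stand-in for fea[np.newaxis]: (nodes, features) -> (1, nodes, features)."""
    return [matrix]


def transpose_last_two(batch):
    """Stand-in for torch.transpose(t, 2, 1): (1, nodes, features) -> (1, features, nodes)."""
    return [[list(column) for column in zip(*matrix)] for matrix in batch]


# Hypothetical small sizes; the ACM data in the log has 3025 nodes and 1870 features.
nb_nodes, ft_size = 4, 3
fea = [[float(n * ft_size + f) for f in range(ft_size)] for n in range(nb_nodes)]

batched = add_batch_axis(fea)
channels_first = transpose_last_two(batched)

print(shape(fea), shape(batched), shape(channels_first))
# prints: (4, 3) (1, 4, 3) (1, 3, 4)
```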
3 changes: 2 additions & 1 deletion model.py
@@ -1,5 +1,6 @@
 import torch
-import trch.nn as nn
+import torch.nn as nn
+from layer import *

 class HeteGAT_multi(nn.Module):
     def __init__(self, inputs_list, nb_classes, nb_nodes, attn_drop, ffd_drop,
2 changes: 1 addition & 1 deletion utils.py
@@ -24,7 +24,7 @@ def adj_to_bias(adj, sizes, nhood=1):
                     mt[g][i][j] = 1.0
     return -1e9 * (1.0 - mt)

-def load_data_dblp(path='/data02/gob/ACM3025.mat'):
+def load_data_dblp(path='data/ACM3025.mat'):
     data = sio.loadmat(path)
     #truelabels:[3025,3] truefeatures:[3025,1870]
     truelabels, truefeatures = data['label'], data['feature'].astype(float)
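
The context lines of this hunk show `adj_to_bias` returning `-1e9 * (1.0 - mt)`: an additive bias that is zero where two nodes are connected and very negative elsewhere. Below is a minimal sketch of how such a bias acts as an attention mask, assuming a 1-hop neighbourhood with self-loops. It uses plain nested lists instead of NumPy, and `adj_to_bias_1hop` and `masked_softmax` are hypothetical helpers, not functions from this repository.

```python
import math

NEG_INF = -1e9  # same constant as in adj_to_bias


def adj_to_bias_1hop(adj):
    """Additive bias for a square 0/1 adjacency matrix:
    0.0 for self-loops and edges, -1e9 for every other pair."""
    n = len(adj)
    return [[0.0 if (i == j or adj[i][j] > 0) else NEG_INF for j in range(n)]
            for i in range(n)]


def masked_softmax(logits, bias_row):
    """Softmax over logits + bias; masked entries underflow to weight 0."""
    shifted = [logit + bias for logit, bias in zip(logits, bias_row)]
    peak = max(shifted)
    exps = [math.exp(value - peak) for value in shifted]
    total = sum(exps)
    return [value / total for value in exps]


# Hypothetical 3-node graph: nodes 0 and 1 are linked, node 2 is isolated.
adj = [[0, 1, 0],
       [1, 0, 0],
       [0, 0, 0]]

bias = adj_to_bias_1hop(adj)
weights = masked_softmax([0.5, 0.5, 0.5], bias[0])
print(weights)
# prints: [0.5, 0.5, 0.0]
```

Node 0 spreads its attention over itself and node 1, and the unconnected node 2 receives zero weight even though its raw score is identical.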