
Unsupervised Heterogeneous Graph Learning #3189

Merged
Merged 26 commits on Mar 27, 2022

Changes from 1 commit

Commits (26)
10bbdea
added unsupervised hetero method
Yonggie Sep 21, 2021
4058254
used 2 meta paths
Yonggie Sep 22, 2021
33c7fee
used 2 meta paths
Yonggie Sep 22, 2021
6ddf312
used 2 meta paths
Yonggie Sep 22, 2021
e8e8145
used 2 meta paths
Yonggie Sep 22, 2021
b469c5a
used 2 meta paths
Yonggie Sep 22, 2021
e79a830
Merge branch 'master' into master
rusty1s Feb 4, 2022
5d9ca86
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2022
4388a4e
Merge branch 'master' into master
Yonggie Mar 10, 2022
8804228
added unsupervised hetero method dmgi, which seems not working
Yonggie Mar 10, 2022
0599874
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
6f3c508
re-implemented DMGI from AAAI, seems not working
Yonggie Mar 10, 2022
3c85311
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
229cb4d
Merge branch 'master' into master
Yonggie Mar 11, 2022
da6e6eb
changed dmgi encoder from single to multiple.
Yonggie Mar 11, 2022
e93bca5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2022
a43b74e
make it flake8 style.
Yonggie Mar 11, 2022
62775d5
make it flake8 style
Yonggie Mar 11, 2022
4980f99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2022
d39120e
Merge branch 'master' into master
rusty1s Mar 11, 2022
dd16651
Merge branch 'master' into master
Yonggie Mar 12, 2022
15ebaeb
Merge branch 'master' into master
rusty1s Mar 14, 2022
de50f85
update
rusty1s Mar 27, 2022
e9d614f
typo
rusty1s Mar 27, 2022
16e6315
typo
rusty1s Mar 27, 2022
690f14b
Merge branch 'master' into master
rusty1s Mar 27, 2022
re-implemented DMGI from AAAI, seems not working
Yonggie committed Mar 10, 2022
commit 6f3c508e8ce67cf275bb3b8a820bab6634f24e7b
73 changes: 33 additions & 40 deletions examples/hetero/hetero_unsupervised_dblp.py
@@ -1,35 +1,33 @@
-import copy
 import math
 import os.path as osp
 
+import copy
+import torch_geometric.transforms as T
 import numpy as np
 import torch
 from sklearn.linear_model import LogisticRegression
 from sklearn.model_selection import train_test_split
 from torch.nn import Parameter
 
-import torch_geometric.transforms as T
 from torch_geometric.datasets import DBLP
-from torch_geometric.nn import GATConv, GCNConv, HeteroConv, SAGEConv
+from torch_geometric.nn import GATConv, GCNConv
 
 EPS = 1e-15
-FEATURE_DIM = 64
-EMBED_DIM = 32
+FEATURE_DIM=64
+EMBED_DIM=32
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), 'data/DBLP')
 dataset = DBLP(path)
 data = dataset[0]
-print(data)
-
 # Initialize author node features
+author_len=data['author'].num_nodes
 data['author'].x = torch.ones(data['author'].num_nodes, FEATURE_DIM)
 # Select metapath APCPA and APA as example metapaths.
-metapaths = [[("author", "paper"), ("paper", "conference"),
-              ("conference", 'paper'), ('paper', 'author')],
-             [("author", "paper"), ("paper", "author")]]
+metapaths = [[("author", "paper"), ("paper", "conference"),("conference",'paper'),('paper','author')],
+             [("author", "paper"), ("paper", "author")]]
 data = T.AddMetaPaths(metapaths)(data)
 
 
 def uniform(size, tensor):
     if tensor is not None:
         bound = 1.0 / math.sqrt(size)
@@ -45,7 +43,6 @@ def __init__(self, in_channels, out_channels, heads):
     def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
         return self.activation(self.conv(x, edge_index))
 
-
 class GCNEncoder(torch.nn.Module):
     def __init__(self, in_channels: int, out_channels: int):
         super(GCNEncoder, self).__init__()
@@ -61,6 +58,7 @@ class HeteroUnsupervised(torch.nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
 
+        self.Z=torch.nn.Parameter(torch.Tensor(author_len,out_channels))
         # for discriminator
         self.embed_size = out_channels
         self.weight = Parameter(torch.Tensor(out_channels, out_channels))
@@ -69,15 +67,18 @@ def __init__(self, in_channels, out_channels):
         # for encoder
         self.encoder = GCNEncoder(in_channels, out_channels)
 
+
     def forward(self, x_dict, edge_index_dict):
         target_embeds = x_dict['author']
 
-        edge_index_a = edge_index_dict['author', 'metapath_0', 'author']
-        edge_index_b = edge_index_dict['author', 'metapath_1', 'author']
+        edge_index_a = edge_index_dict['author','metapath_0','author']
+        edge_index_b = edge_index_dict['author','metapath_1','author']
 
         mp_edge_pairs = [edge_index_a, edge_index_b]
 
-        pos_feats = [target_embeds for _ in range(len(mp_edge_pairs))]
+        pos_feats = []
+        for _ in range(len(mp_edge_pairs)):
+            pos_feats.append(copy.deepcopy(target_embeds))
         pos_embeds = []
         neg_embeds = []
         summaries = []
@@ -97,30 +98,15 @@ def forward(self, x_dict, edge_index_dict):
         return pos_embeds, neg_embeds, summaries
 
     def embed(self, x_dict, edge_index_dict):
-        target_embeds = x_dict['author']
-
-        edge_index_a = edge_index_dict['author', 'metapath_0', 'author']
-        edge_index_b = edge_index_dict['author', 'metapath_1', 'author']
-
-        mp_edge_pairs = [edge_index_a, edge_index_b]
-
-        pos_feats = []
-        for _ in range(len(mp_edge_pairs)):
-            pos_feats.append(copy.deepcopy(target_embeds))
-        pos_embeds = []
-        for pos_feat, edge_index in zip(pos_feats, mp_edge_pairs):
-            pos_embed = self.encoder(pos_feat, edge_index)
-            pos_embeds.append(pos_embed)
+        return self.Z
 
-        # mean aggeragation
-        final_embed = sum(pos_embeds) / len(mp_edge_pairs)
-        return final_embed
 
     def discriminate(self, z, summary, sigmoid=True):
         value = torch.matmul(z, torch.matmul(self.weight, summary))
         return torch.sigmoid(value) if sigmoid else value
 
     def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.Z)
         uniform(self.embed_size, self.weight)
 
     def loss(self, pos_embeds, neg_embeds, summaries):
@@ -135,10 +121,17 @@ def loss(self, pos_embeds, neg_embeds, summaries):
                 neg_embed, summary, sigmoid=True) + EPS).mean()
             total_loss += (pos_loss + neg_loss)
 
+        pos_mean=torch.stack(pos_embeds).mean(dim=0)
+        neg_mean=torch.stack(neg_embeds).mean(dim=0)
+        # consensus regularizer
+        pos_reg_loss = ((self.Z - pos_mean) ** 2).sum()
+        neg_reg_loss = ((self.Z - neg_mean) ** 2).sum()
+        reg_loss = pos_reg_loss - neg_reg_loss
+        total_loss+=reg_loss
         return total_loss
 
 
-model = HeteroUnsupervised(out_channels=FEATURE_DIM, in_channels=FEATURE_DIM)
+model = HeteroUnsupervised(out_channels=FEATURE_DIM,in_channels=FEATURE_DIM)
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -157,9 +150,9 @@ def train():
         model(data.x_dict, data.edge_index_dict)
 
     # pick out the train embed
-    mask = data['author']['train_mask']
-    pos_embeds = [pos_embed[mask] for pos_embed in pos_embeds]
-    neg_embeds = [neg_embed[mask] for neg_embed in neg_embeds]
+    # mask = data['author']['train_mask']
+    # pos_embeds = [pos_embed[mask] for pos_embed in pos_embeds]
+    # neg_embeds = [neg_embed[mask] for neg_embed in neg_embeds]
 
     loss = model.loss(pos_embeds, neg_embeds, summaries)
     loss.backward()
@@ -177,8 +170,8 @@ def test():
     labels = data['author'].y[mask]
 
     valid_embed = pos_embed[mask]
-    valid_embed = valid_embed.cpu().numpy()
-    labels = labels.cpu().numpy()
+    valid_embed=valid_embed.cpu().numpy()
+    labels=labels.cpu().numpy()
 
     train_z, test_z, train_y, test_y = \
         train_test_split(valid_embed, labels, train_size=0.8)
@@ -190,8 +183,8 @@ def test():
     labels = data['author'].y[mask]
 
     test_embed = pos_embed[mask]
-    test_embed = test_embed.cpu().numpy()
-    labels = labels.cpu().numpy()
+    test_embed=test_embed.cpu().numpy()
+    labels=labels.cpu().numpy()
 
     train_z, test_z, train_y, test_y = \
         train_test_split(test_embed, labels, train_size=0.8)
@@ -203,7 +196,7 @@ def test():
     return val_score, test_score
 
 
-for epoch in range(1, 1000):
+for epoch in range(1, 100):
     loss = train()
     valid_acc, test_acc = test()
     print(f'Epoch: {epoch:03d}, Loss: {loss:.4f},\
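Note on the method, for readers of this commit in isolation: the code above follows DMGI (Park et al., "Unsupervised Attributed Multiplex Network Embedding", AAAI 2020). Each metapath-induced view of the graph is encoded separately, a DGI-style bilinear discriminator scores node embeddings against a per-view summary vector (real features should score high, corrupted ones low), and a learnable consensus embedding Z is pulled toward the mean of the positive per-view embeddings and pushed away from the mean of the corrupted ones. The sketch below restates that objective outside the diff; it is a minimal reading of the code above under its apparent shape conventions, not the PR's code itself.

import torch

EPS = 1e-15  # numerical floor inside the logs, as in the example above


def dmgi_style_loss(pos_embeds, neg_embeds, summaries, Z, weight):
    # pos_embeds, neg_embeds: lists of [num_nodes, dim] tensors, one per
    # metapath view (negatives come from corrupted, e.g. shuffled, features).
    # summaries: list of [dim] readout vectors; Z: [num_nodes, dim] consensus
    # parameter; weight: [dim, dim] bilinear discriminator matrix.
    total_loss = 0.0
    for pos, neg, summary in zip(pos_embeds, neg_embeds, summaries):
        pos_score = torch.sigmoid(pos @ weight @ summary)  # [num_nodes]
        neg_score = torch.sigmoid(neg @ weight @ summary)  # [num_nodes]
        total_loss += -torch.log(pos_score + EPS).mean()      # real pairs -> 1
        total_loss += -torch.log(1 - neg_score + EPS).mean()  # corrupted -> 0
    # Consensus regularizer: attract Z to the mean positive embedding and
    # repel it from the mean corrupted embedding.
    pos_mean = torch.stack(pos_embeds).mean(dim=0)
    neg_mean = torch.stack(neg_embeds).mean(dim=0)
    total_loss += ((Z - pos_mean) ** 2).sum() - ((Z - neg_mean) ** 2).sum()
    return total_loss

One caveat worth flagging: the published DMGI objective scales the consensus term by a coefficient, and the repulsion term -((Z - neg_mean) ** 2).sum() is unbounded below, so without that scaling the regularizer can dominate the discriminator loss. That may bear on the "seems not working" remarks in the commit history.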