Skip to content

Commit

Permalink
✨ add GAT
Browse files Browse the repository at this point in the history
  • Loading branch information
TristanBilot committed Mar 28, 2022
1 parent a9680c4 commit 115f2c0
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 11 deletions.
1 change: 1 addition & 0 deletions phishGNN/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

from .gcn import GCN
from .gin import GIN
from .gat import GAT
from .mem_pool import MemPool
97 changes: 97 additions & 0 deletions phishGNN/models/gat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_mean_pool

# dataset = 'Cora'
# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
# dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
# data = dataset[0]


class GAT(torch.nn.Module):
    """Graph Attention Network for graph-level classification.

    Two `GATConv` layers with ELU + dropout in between, followed by
    `global_mean_pool` to produce one logit vector per graph. The pooled
    graph embeddings of the last forward pass are kept in `self.embeddings`
    (read by the visualization code in this project).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, device):
        """Build the two attention layers and move the model to `device`.

        Args:
            in_channels: number of input node features.
            hidden_channels: per-head hidden size; also used as the number
                of attention heads in the first layer.
            out_channels: number of output classes.
            device: torch device the model (and batches) should live on.
        """
        super().__init__()

        self.device = device

        # First layer uses `hidden_channels` heads with concatenation, so its
        # output width is hidden_channels * hidden_channels.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=hidden_channels, dropout=0.6)
        # On the Pubmed dataset, use heads=8 in conv2.
        self.conv2 = GATConv(
            hidden_channels * hidden_channels, out_channels, heads=1, concat=False, dropout=0.6)
        # Pooled graph embeddings from the most recent forward pass.
        self.embeddings = None

        # BUGFIX: self.to(device) must run AFTER the layers are created.
        # The original called it before conv1/conv2 existed, so their
        # parameters were never moved off the CPU.
        self.to(device)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits of shape (num_graphs, out_channels)."""
        # Node features may arrive as a non-float dtype; GATConv needs floats.
        x = x.float()
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)

        # Aggregate node representations into one vector per graph.
        x = global_mean_pool(x, batch)
        self.embeddings = x

        return x

    def fit(
        self,
        train_loader,
        optimizer,
        loss_fn,
        device,
    ):
        """Train for one epoch; return the mean loss per graph.

        Args:
            train_loader: PyG DataLoader yielding training batches.
            optimizer: torch optimizer over this model's parameters.
            loss_fn: loss taking (logits, target) — targets are cast to long.
            device: device to move each batch to before the forward pass.
        """
        self.train()

        total_loss = 0.0
        for data in train_loader:
            # BUGFIX: move the batch to the same device as the model;
            # the original never used the `device` argument.
            data = data.to(device)
            optimizer.zero_grad()
            out = self(data.x, data.edge_index, data.batch)
            loss = loss_fn(out, data.y.long())
            loss.backward()
            optimizer.step()
            # Weight the batch loss by its number of graphs so the epoch
            # average is per-graph, not per-batch.
            total_loss += float(loss) * data.num_graphs

        return total_loss / len(train_loader.dataset)

    @torch.no_grad()
    def test(self, loader):
        """Return classification accuracy (fraction correct) over `loader`."""
        self.eval()

        correct = 0
        for data in loader:
            # Keep evaluation batches on the model's device as well.
            data = data.to(self.device)
            out = self(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())

        return correct / len(loader.dataset)

# @torch.no_grad()
# def test(data):
# self.eval()
# out, accs = self(data.x, data.edge_index), []
# for _, mask in data('train_mask', 'val_mask', 'test_mask'):
# acc = float((out[mask].argmax(-1) == data.y[mask]).sum() / mask.sum())
# accs.append(acc)
# return accs


# for epoch in range(1, 201):
# train(data)
# train_acc, val_acc, test_acc = test(data)
# print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
# f'Test: {test_acc:.4f}')
28 changes: 17 additions & 11 deletions phishGNN/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_geometric.loader import DataLoader

from visualization import visualize, plot_embeddings
from models import GCN, GIN, MemPool
from models import GCN, GIN, GAT, MemPool


if __name__ == "__main__":
Expand All @@ -22,31 +22,37 @@

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MemPool(
in_channels=dataset.num_features,
hidden_channels=32,
out_channels=dataset.num_classes,
device=device,
)
# model = GIN(
# model = GAT(
# in_channels=dataset.num_features,
# hidden_channels=8,
# out_channels=dataset.num_classes,
# device=device,
# )
# model = MemPool(
# in_channels=dataset.num_features,
# hidden_channels=32,
# out_channels=dataset.num_classes,
# device=device,
# )
# model = GCN(
# in_channels=dataset.num_node_features,
# model = GIN(
# in_channels=dataset.num_features,
# hidden_channels=32,
# out_channels=dataset.num_classes,
# device=device,
# )
model = GCN(
in_channels=dataset.num_features,
hidden_channels=32,
out_channels=dataset.num_classes,
device=device,
)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=4e-5)
loss_fn = torch.nn.CrossEntropyLoss()


for epoch in range(1, 2):
for epoch in range(1, 5):
loss = model.fit(train_loader, optimizer, loss_fn, device)
train_acc = model.test(train_loader)
test_acc = model.test(test_loader)
Expand Down

0 comments on commit 115f2c0

Please sign in to comment.