Commit fe473a4
✨ add MemPool + move train() and test() to model classes
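Adds a MemPool graph-classification model (apparently adapted from the PyTorch Geometric mem_pool example, given the commented-out TUDataset/MUTAG scaffolding) and moves the training/evaluation loops out of training.py into fit()/test() methods on each model class. Every model now exposes the same interface, so the training script reduces to:

    loss = model.fit(train_loader, optimizer, loss_fn, device)
    train_acc = model.test(train_loader)
    test_acc = model.test(test_loader)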
TristanBilot committed Mar 27, 2022
1 parent b63f414 commit fe473a4
Showing 5 changed files with 224 additions and 46 deletions.
1 change: 1 addition & 0 deletions phishGNN/models/__init__.py
@@ -2,3 +2,4 @@

 from .gcn import GCN
 from .gin import GIN
+from .mem_pool import MemPool
40 changes: 39 additions & 1 deletion phishGNN/models/gcn.py
@@ -6,8 +6,12 @@


 class GCN(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
+    def __init__(self, in_channels, hidden_channels, out_channels, device):
         super(GCN, self).__init__()
+
+        self.device = device
+        self.to(device)
+
         torch.manual_seed(12345)
         self.conv1 = GCNConv(in_channels, hidden_channels)
         self.conv2 = GCNConv(hidden_channels, hidden_channels)
@@ -28,3 +32,37 @@ def forward(self, x, edge_index, batch):
         x = self.lin(x)

         return x
+
+    def fit(
+        self,
+        train_loader,
+        optimizer,
+        loss_fn,
+        device,
+    ):
+        self.train()
+
+        total_loss = 0
+        for data in train_loader:  # iterate in batches over the training dataset
+            data = data.to(device)
+            out = self(data.x, data.edge_index, data.batch)  # perform a single forward pass
+            loss = loss_fn(out, data.y.long())
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            total_loss += float(loss) * data.num_graphs
+
+        return total_loss / len(train_loader.dataset)
+
+
+    @torch.no_grad()
+    def test(self, loader):
+        self.eval()
+
+        correct = 0
+        for data in loader:
+            data = data.to(self.device)  # move the batch to the model's device (as in MemPool.test)
+            out = self(data.x, data.edge_index, data.batch)
+            pred = out.argmax(dim=1)  # use the class with the highest probability
+            correct += int((pred == data.y).sum())  # check against ground-truth labels
+
+        return correct / len(loader.dataset)
43 changes: 38 additions & 5 deletions phishGNN/models/gin.py
@@ -12,25 +12,24 @@


 class GIN(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
+    def __init__(self, in_channels, hidden_channels, out_channels, device):
         super().__init__()

+        self.device = device
+        self.to(device)
+
         self.conv1 = GINConv(
             Sequential(Linear(in_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(),
                        Linear(hidden_channels, hidden_channels), ReLU()))
-
         self.conv2 = GINConv(
             Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(),
                        Linear(hidden_channels, hidden_channels), ReLU()))
-
         self.conv3 = GINConv(
             Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(),
                        Linear(hidden_channels, hidden_channels), ReLU()))
-
         self.conv4 = GINConv(
             Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(),
                        Linear(hidden_channels, hidden_channels), ReLU()))
-
         self.conv5 = GINConv(
             Sequential(Linear(hidden_channels, hidden_channels), BatchNorm1d(hidden_channels), ReLU(),
                        Linear(hidden_channels, hidden_channels), ReLU()))
@@ -53,3 +52,37 @@ def forward(self, x, edge_index, batch):
         return x
         # if NLLLoss is used, return log-probabilities instead:
         # return F.log_softmax(x, dim=-1)
+
+    def fit(
+        self,
+        train_loader,
+        optimizer,
+        loss_fn,
+        device,
+    ):
+        self.train()
+
+        total_loss = 0
+        for data in train_loader:
+            data = data.to(device)
+            out = self(data.x, data.edge_index, data.batch)
+            loss = loss_fn(out, data.y.long())
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            total_loss += float(loss) * data.num_graphs
+
+        return total_loss / len(train_loader.dataset)
+
+
+    @torch.no_grad()
+    def test(self, loader):
+        self.eval()
+
+        correct = 0
+        for data in loader:
+            data = data.to(self.device)  # move the batch to the model's device (as in MemPool.test)
+            out = self(data.x, data.edge_index, data.batch)
+            pred = out.argmax(dim=1)
+            correct += int((pred == data.y).sum())
+
+        return correct / len(loader.dataset)
125 changes: 125 additions & 0 deletions phishGNN/models/mem_pool.py
@@ -0,0 +1,125 @@
+import os.path as osp
+
+import torch
+import torch.nn.functional as F
+from torch.nn import BatchNorm1d, LeakyReLU, Linear
+
+from torch_geometric.datasets import TUDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling
+
+# Commented-out scaffolding, apparently kept from the PyTorch Geometric
+# mem_pool example this model is adapted from:
+# path = osp.join('data', 'TUD')
+# dataset = TUDataset(path, name="MUTAG", use_node_attr=True)
+# dataset.data.x = dataset.data.x[:, :-3]  # only use non-binary features.
+# dataset = dataset.shuffle()
+
+# n = (len(dataset)) // 10
+# test_dataset = dataset[:n]
+# val_dataset = dataset[n:2 * n]
+# train_dataset = dataset[2 * n:]
+
+# test_loader = DataLoader(test_dataset, batch_size=20)
+# val_loader = DataLoader(val_dataset, batch_size=20)
+# train_loader = DataLoader(train_dataset, batch_size=20)
+
+
+class MemPool(torch.nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, device, dropout=0.5):
+        super().__init__()
+        self.device = device
+        self.dropout = dropout
+
+        self.lin = Linear(in_channels, hidden_channels)
+
+        self.convs = torch.nn.ModuleList()
+        for i in range(2):
+            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)
+            norm = BatchNorm1d(hidden_channels)
+            act = LeakyReLU()
+            self.convs.append(
+                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))
+
+        # MemPooling(in_channels, out_channels, heads, num_clusters): mem1
+        # soft-assigns each graph's node embeddings to 10 clusters with 5
+        # attention heads, and mem2 pools those clusters into a single
+        # [batch_size, 1, out_channels] graph embedding.
+        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)
+        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)
+
+        # called last so that all layers registered above are moved to `device`
+        # (unlike GCN/GIN, MemPool is not moved again in training.py)
+        self.to(device)
+
+
+    def forward(self, x, edge_index, batch):
+        x = x.float()
+        x = self.lin(x)
+        for conv in self.convs:
+            x = conv(x, edge_index)
+
+        x, S1 = self.mem1(x, batch)
+        x = F.leaky_relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x, S2 = self.mem2(x)
+
+        # per-graph log-probabilities plus the KL clustering loss of both
+        # pooling layers
+        return (
+            F.log_softmax(x.squeeze(1), dim=-1),
+            MemPooling.kl_loss(S1) + MemPooling.kl_loss(S2),
+        )
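+    # Training follows the two-phase scheme of the PyTorch Geometric mem_pool
+    # example: fit() first optimizes the encoder and classifier with the memory
+    # keys frozen, then unfreezes the keys and optimizes the KL clustering loss
+    # accumulated over the whole training set.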
+
+
+    def fit(
+        self,
+        train_loader,
+        optimizer,
+        loss_fn,  # unused: forward() returns log-probabilities, so F.nll_loss is applied directly
+        device,   # unused: batches are moved to self.device instead
+    ):
+        self.train()
+
+        # phase 1: optimize conv/classifier weights with the memory keys frozen
+        self.mem1.k.requires_grad = False
+        self.mem2.k.requires_grad = False
+        for data in train_loader:
+            optimizer.zero_grad()
+            data = data.to(self.device)
+            out = self(data.x, data.edge_index, data.batch)[0]
+            loss = F.nll_loss(out, data.y.long())
+            loss.backward()
+            optimizer.step()
+
+        # phase 2: unfreeze the keys and optimize the KL clustering loss,
+        # accumulated over the full training set before a single step
+        kl_loss = 0.
+        self.mem1.k.requires_grad = True
+        self.mem2.k.requires_grad = True
+        optimizer.zero_grad()
+        for data in train_loader:
+            data = data.to(self.device)
+            kl_loss += self(data.x, data.edge_index, data.batch)[1]
+
+        kl_loss /= len(train_loader.dataset)
+        kl_loss.backward()
+        optimizer.step()
+
+        # note: returns the KL loss, not the classification loss
+        return kl_loss
+
+
+    @torch.no_grad()
+    def test(self, loader):
+        self.eval()
+        correct = 0
+        for data in loader:
+            data = data.to(self.device)
+            out = self(data.x, data.edge_index, data.batch)[0]
+            pred = out.argmax(dim=-1)
+            correct += int((pred == data.y).sum())
+        return correct / len(loader.dataset)
+
+
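+# Minimal usage sketch (assuming a dataset, loaders and an optimizer built as in
+# training.py; MemPool.fit ignores its `loss_fn` argument):
+#
+#     model = MemPool(dataset.num_features, 32, dataset.num_classes, device)
+#     kl = model.fit(train_loader, optimizer, loss_fn=None, device=device)
+#     acc = model.test(test_loader)
+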
+# Reference training loop kept from the original example (it calls the
+# module-level train()/test() helpers that fit()/test() above replace):
+# patience = start_patience = 250
+# test_acc = best_val_acc = 0.
+# for epoch in range(1, 2001):
+#     train()
+#     val_acc = test(val_loader)
+#     if epoch % 500 == 0:
+#         optimizer.param_groups[0]['lr'] *= 0.5
+#     if best_val_acc < val_acc:
+#         patience = start_patience
+#         best_val_acc = val_acc
+#         test_acc = test(test_loader)
+#     else:
+#         patience -= 1
+#     print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')
+#     if patience <= 0:
+#         break
61 changes: 21 additions & 40 deletions phishGNN/training.py
@@ -5,7 +5,7 @@
 from torch_geometric.loader import DataLoader

 from visualization import visualize
-from models import GCN, GIN
+from models import GCN, GIN, MemPool


 if __name__ == "__main__":
@@ -22,51 +22,32 @@

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-    model = GIN(
-        in_channels=dataset.num_features,
-        hidden_channels=32,
-        out_channels=dataset.num_classes,
-    ).to(device)
-    # model = GCN(
-    #     in_channels=dataset.num_node_features,
-    #     hidden_channels=64,
+    # model = MemPool(
+    #     in_channels=dataset.num_features,
+    #     hidden_channels=32,
     #     out_channels=dataset.num_classes,
+    #     device=device,
     # )
+    # model = GIN(
+    #     in_channels=dataset.num_features,
+    #     hidden_channels=32,
+    #     out_channels=dataset.num_classes,
+    #     device=device,
+    # ).to(device)
+    model = GCN(
+        in_channels=dataset.num_node_features,
+        hidden_channels=64,
+        out_channels=dataset.num_classes,
+        device=device,
+    ).to(device)
     print(model)

-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=4e-5)
     loss_fn = torch.nn.CrossEntropyLoss()

-    def train():
-        model.train()
-
-        total_loss = 0
-        for data in train_loader:  # Iterate in batches over the training dataset.
-            data = data.to(device)
-            out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
-            loss = loss_fn(out, data.y.long())  # Compute the loss.
-            loss.backward()  # Derive gradients.
-            optimizer.step()  # Update parameters based on gradients.
-            optimizer.zero_grad()  # Clear gradients.
-            total_loss += float(loss) * data.num_graphs
-
-        return total_loss / len(train_loader.dataset)
-
-
-    def test(loader):
-        model.eval()
-
-        correct = 0
-        for data in loader:  # Iterate in batches over the training/test dataset.
-            out = model(data.x, data.edge_index, data.batch)
-            pred = out.argmax(dim=1)  # Use the class with highest probability.
-            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
-
-        return correct / len(loader.dataset)  # Derive ratio of correct predictions.
-
-
     for epoch in range(1, 171):
-        loss = train()
-        train_acc = test(train_loader)
-        test_acc = test(test_loader)
+        loss = model.fit(train_loader, optimizer, loss_fn, device)
+        train_acc = model.test(train_loader)
+        test_acc = model.test(test_loader)
         print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
