Skip to content

Commit

Permalink
🎨 add default values for model parameters and implement combinatorial…
Browse files Browse the repository at this point in the history
… model benchmarking
  • Loading branch information
TristanBilot committed Mar 28, 2022
1 parent d07e6ac commit a7ee8e8
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 35 deletions.
17 changes: 13 additions & 4 deletions phishGNN/models/cluster_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@


class ClusterGCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device, num_layers=6):
def __init__(
self,
in_channels=None,
hidden_channels=32,
out_channels=None,
device=None,
pooling_fn=global_mean_pool,
num_layers=6,
):
super().__init__()

if hidden_channels is None:
hidden_channels = 32
self.pooling_fn = pooling_fn
self.device = device
self.to(device)

self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
Expand All @@ -32,7 +41,7 @@ def forward(self, x, edge_index, batch):
x = F.dropout(x, p=0.2, training=self.training)

x = self.convs[-1](x, edge_index)
x = global_mean_pool(x, batch)
x = self.pooling_fn(x, batch)
self.embeddings = x

return x
15 changes: 10 additions & 5 deletions phishGNN/models/gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@


class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
def __init__(
self,
in_channels=None,
hidden_channels=8,
out_channels=None,
pooling_fn=global_mean_pool,
device=None,
):
super(GAT, self).__init__()

if hidden_channels is None:
hidden_channels = 8

self.pooling_fn = pooling_fn
self.device = device
self.to(device)

Expand All @@ -26,7 +31,7 @@ def forward(self, x, edge_index, batch):
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)

x = global_mean_pool(x, batch)
x = self.pooling_fn(x, batch)
self.embeddings = x

return x
15 changes: 10 additions & 5 deletions phishGNN/models/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@


class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
def __init__(
self,
in_channels=None,
hidden_channels=32,
out_channels=None,
pooling_fn=global_mean_pool,
device=None,
):
super(GCN, self).__init__()

if hidden_channels is None:
hidden_channels = 32

self.pooling_fn = pooling_fn
self.device = device
self.to(device)

Expand All @@ -30,7 +35,7 @@ def forward(self, x, edge_index, batch):
x = x.relu()
x = self.conv3(x, edge_index)

x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
x = self.pooling_fn(x, batch)
self.embeddings = x

x = F.dropout(x, p=0.5, training=self.training)
Expand Down
15 changes: 10 additions & 5 deletions phishGNN/models/gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@


class GIN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
def __init__(
self,
in_channels=None,
hidden_channels=32,
out_channels=None,
pooling_fn=global_add_pool,
device=None,
):
super().__init__()

if hidden_channels is None:
hidden_channels = 32

self.pooling_fn = pooling_fn
self.device = device
self.to(device)

Expand Down Expand Up @@ -48,7 +53,7 @@ def forward(self, x, edge_index, batch):
x = self.conv4(x, edge_index)
x = self.conv5(x, edge_index)

x = global_add_pool(x, batch)
x = self.pooling_fn(x, batch)
self.embeddings = x

x = self.lin1(x).relu()
Expand Down
15 changes: 10 additions & 5 deletions phishGNN/models/graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@


class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
def __init__(
self,
in_channels=None,
hidden_channels=16,
out_channels=None,
pooling_fn=global_mean_pool,
device=None,
):
super(GraphSAGE, self).__init__()

if hidden_channels is None:
hidden_channels = 16

self.pooling_fn = pooling_fn
self.device = device
self.to(device)

Expand All @@ -26,7 +31,7 @@ def forward(self, x, edge_index, batch):
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)

x = global_mean_pool(x, batch)
x = self.pooling_fn(x, batch)
self.embeddings = x

return x
13 changes: 9 additions & 4 deletions phishGNN/models/mem_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@


class MemPool(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device, dropout=0.5):
def __init__(
self,
in_channels=None,
hidden_channels=32,
out_channels=None,
pooling_fn=None,
device=None,
dropout=0.5,
):
super().__init__()

if hidden_channels is None:
hidden_channels = 32

self.device = device
self.to(device)
self.dropout = dropout
Expand Down
45 changes: 38 additions & 7 deletions phishGNN/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import itertools
import json
from collections import defaultdict
from pprint import pprint

import torch
import torch_geometric.nn as nn
from dataset import PhishingDataset
from torch_geometric.loader import DataLoader

Expand Down Expand Up @@ -66,19 +71,36 @@ def test(loader):
MemPool,
]

poolings = [
nn.global_mean_pool,
nn.global_max_pool,
nn.global_add_pool,
]

hidden_neurons = [
32,
64,
128,
]

lr = 0.01
weight_decay = 4e-5
epochs = 2
epochs = 10

accuracies = {}
for model in models:
accuracies = defaultdict(lambda: [])
for (model, pooling, neurons) in itertools.product(
models,
poolings,
hidden_neurons,
):
model = model(
in_channels=dataset.num_features,
hidden_channels=None,
hidden_channels=neurons,
out_channels=dataset.num_classes,
pooling_fn=pooling,
device=device,
)
# print(model)
print(f"\n{model.__class__.__name__}, {pooling.__name__}, {neurons}")

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_fn = torch.nn.CrossEntropyLoss()
Expand All @@ -96,9 +118,18 @@ def test(loader):
test_accs.append(test_acc)
print(f'Epoch: {(epoch+1):03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

accuracies[model.__class__] = (mean_std_error(train_accs), mean_std_error(test_accs))
accuracies[model.__class__.__name__].append({
f'{pooling.__name__}, {neurons}': {
'train': mean_std_error(train_accs),
'test': mean_std_error(test_accs),
}
})

with open('training.logs', 'w') as logs:
formatted = json.dumps(accuracies, indent=2)
logs.write(formatted)

# loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# plot_embeddings(model, loader)

print(accuracies)
pprint(accuracies)

0 comments on commit a7ee8e8

Please sign in to comment.