♻️ move test() and fit() methods to training
TristanBilot committed Mar 28, 2022
1 parent ba889fa commit 349a3ee
Showing 7 changed files with 91 additions and 302 deletions.
45 changes: 4 additions & 41 deletions phishGNN/models/cluster_gcn.py
@@ -1,19 +1,16 @@
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score

from torch_geometric.data import Batch
from torch_geometric.datasets import PPI
from torch_geometric.loader import ClusterData, ClusterLoader, DataLoader
from torch_geometric.nn import BatchNorm, SAGEConv
from torch_geometric.nn import global_mean_pool


class ClusterGCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device, num_layers=6):
super().__init__()

if hidden_channels is None:
hidden_channels = 32

self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
@@ -39,37 +36,3 @@ def forward(self, x, edge_index, batch):
self.embeddings = x

return x

def fit(
self,
train_loader,
optimizer,
loss_fn,
device,
):
self.train()

total_loss = 0
for data in train_loader:
data = data.to(device)
out = self(data.x, data.edge_index, data.batch)
loss = loss_fn(out, data.y.long())
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += float(loss) * data.num_graphs

return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(self, loader):
self.eval()

correct = 0
for data in loader:
out = self(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)
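
The fit() method removed above scales each batch loss by data.num_graphs before dividing by the dataset size. Batches can hold different numbers of graphs, so the weighted sum recovers the exact per-graph mean loss rather than a mean of batch means. A self-contained toy check (the numbers are illustrative, not from this repo):

losses = [0.9, 0.4]                # mean loss reported for two batches
sizes = [3, 1]                     # graphs per batch (uneven)
per_graph = [0.9, 0.9, 0.9, 0.4]   # individual graph losses behind those means
weighted = sum(l * n for l, n in zip(losses, sizes)) / sum(sizes)
assert abs(weighted - sum(per_graph) / len(per_graph)) < 1e-9
print(weighted)  # 0.775, vs. the naive mean of batch means: 0.65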
71 changes: 3 additions & 68 deletions phishGNN/models/gat.py
@@ -1,28 +1,20 @@
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_mean_pool

# dataset = 'Cora'
# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
# dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
# data = dataset[0]


class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
super(GAT, self).__init__()

if hidden_channels is None:
hidden_channels = 8

self.device = device
self.to(device)

self.conv1 = GATConv(in_channels, hidden_channels, heads=hidden_channels, dropout=0.6)
# On the Pubmed dataset, use heads=8 in conv2.
self.conv2 = GATConv(
hidden_channels * hidden_channels, out_channels, heads=1, concat=False, dropout=0.6)
self.embeddings = None
@@ -38,60 +30,3 @@ def forward(self, x, edge_index, batch):
self.embeddings = x

return x
# return F.log_softmax(x, dim=-1)


# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = Net(dataset.num_features, dataset.num_classes).to(device)
# data = data.to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


def fit(
self,
train_loader,
optimizer,
loss_fn,
device,
):
self.train()

total_loss = 0
for data in train_loader:
out = self(data.x, data.edge_index, data.batch)
loss = loss_fn(out, data.y.long())
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += float(loss) * data.num_graphs

return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(self, loader):
self.eval()

correct = 0
for data in loader:
out = self(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)

# @torch.no_grad()
# def test(data):
# self.eval()
# out, accs = self(data.x, data.edge_index), []
# for _, mask in data('train_mask', 'val_mask', 'test_mask'):
# acc = float((out[mask].argmax(-1) == data.y[mask]).sum() / mask.sum())
# accs.append(acc)
# return accs


# for epoch in range(1, 201):
# train(data)
# train_acc, val_acc, test_acc = test(data)
# print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
# f'Test: {test_acc:.4f}')
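
The hidden_channels * hidden_channels input size of conv2 in the diff above follows from conv1 using heads=hidden_channels: GATConv concatenates its per-head outputs by default (concat=True), so conv1 emits hidden_channels features per head times hidden_channels heads. A quick shape check on a toy graph (values are illustrative; assumes torch_geometric is installed):

import torch
from torch_geometric.nn import GATConv

conv1 = GATConv(in_channels=16, out_channels=8, heads=8, dropout=0.6)  # concat=True is the default
x = torch.randn(5, 16)                        # 5 nodes, 16 input features
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])  # small ring graph
print(conv1(x, edge_index).shape)             # torch.Size([5, 64]) = 8 heads * 8 channels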
37 changes: 3 additions & 34 deletions phishGNN/models/gcn.py
@@ -9,6 +9,9 @@ class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
super(GCN, self).__init__()

if hidden_channels is None:
hidden_channels = 32

self.device = device
self.to(device)

@@ -34,37 +37,3 @@ def forward(self, x, edge_index, batch):
x = self.lin(x)

return x

def fit(
self,
train_loader,
optimizer,
loss_fn,
device,
):
self.train()

total_loss = 0
for data in train_loader:
data = data.to(device)
out = self(data.x, data.edge_index, data.batch)
loss = loss_fn(out, data.y.long())
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += float(loss) * data.num_graphs

return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(self, loader):
self.eval()

correct = 0
for data in loader:
out = self(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)
42 changes: 4 additions & 38 deletions phishGNN/models/gin.py
@@ -3,9 +3,8 @@

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Dropout, Linear, ReLU, Sequential
import torch_geometric.transforms as T

from torch.nn import BatchNorm1d, Dropout, Linear, ReLU, Sequential
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool
@@ -15,6 +14,9 @@ class GIN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
super().__init__()

if hidden_channels is None:
hidden_channels = 32

self.device = device
self.to(device)

@@ -54,39 +56,3 @@ def forward(self, x, edge_index, batch):
x = self.lin2(x)

return x
# if NLLLoss used must use softmax
# return F.log_softmax(x, dim=-1)

def fit(
self,
train_loader,
optimizer,
loss_fn,
device,
):
self.train()

total_loss = 0
for data in train_loader:
data = data.to(device)
out = self(data.x, data.edge_index, data.batch)
loss = loss_fn(out, data.y.long())
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += float(loss) * data.num_graphs

return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(self, loader):
self.eval()

correct = 0
for data in loader:
out = self(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)
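
The removed comment in gin.py ("if NLLLoss used must use softmax") points at a detail worth spelling out: these models return raw logits, and NLLLoss expects log-probabilities, so pairing the two would require a log_softmax on the output. CrossEntropyLoss applies log_softmax internally, which is why returning logits works with it directly. A small equivalence check (illustrative, not from this repo):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)                        # 4 graphs, 2 classes
y = torch.tensor([0, 1, 1, 0])
a = F.cross_entropy(logits, y)                    # logits in; log_softmax applied internally
b = F.nll_loss(F.log_softmax(logits, dim=-1), y)  # explicit log_softmax + NLLLoss
assert torch.allclose(a, b)                       # identical losses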
37 changes: 3 additions & 34 deletions phishGNN/models/graphsage.py
@@ -9,6 +9,9 @@ class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device):
super(GraphSAGE, self).__init__()

if hidden_channels is None:
hidden_channels = 16

self.device = device
self.to(device)

@@ -27,37 +30,3 @@ def forward(self, x, edge_index, batch):
self.embeddings = x

return x

def fit(
self,
train_loader,
optimizer,
loss_fn,
device,
):
self.train()

total_loss = 0
for data in train_loader:
data = data.to(device)
out = self(data.x, data.edge_index, data.batch)
loss = loss_fn(out, data.y.long())
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += float(loss) * data.num_graphs

return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(self, loader):
self.eval()

correct = 0
for data in loader:
out = self(data.x, data.edge_index, data.batch)
pred = out.argmax(dim=1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)
41 changes: 4 additions & 37 deletions phishGNN/models/mem_pool.py
@@ -1,31 +1,16 @@
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, LeakyReLU, Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling

# path = osp.join('data', 'TUD')
# dataset = TUDataset(path, name="MUTAG", use_node_attr=True)
# dataset.data.x = dataset.data.x[:, :-3] # only use non-binary features.
# dataset = dataset.shuffle()

# n = (len(dataset)) // 10
# test_dataset = dataset[:n]
# val_dataset = dataset[n:2 * n]
# train_dataset = dataset[2 * n:]

# test_loader = DataLoader(test_dataset, batch_size=20)
# val_loader = DataLoader(val_dataset, batch_size=20)
# train_loader = DataLoader(train_dataset, batch_size=20)


class MemPool(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, device, dropout=0.5):
super().__init__()

if hidden_channels is None:
hidden_channels = 32

self.device = device
self.to(device)
self.dropout = dropout
@@ -107,21 +92,3 @@ def test(self, loader):
pred = out.argmax(dim=-1)
correct += int((pred == data.y).sum())
return correct / len(loader.dataset)


# patience = start_patience = 250
# test_acc = best_val_acc = 0.
# for epoch in range(1, 2001):
# train()
# val_acc = test(val_loader)
# if epoch % 500 == 0:
# optimizer.param_groups[0]['lr'] *= 0.5
# if best_val_acc < val_acc:
# patience = start_patience
# best_val_acc = val_acc
# test_acc = test(test_loader)
# else:
# patience -= 1
# print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')
# if patience <= 0:
# break
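
The seventh changed file, presumably the training module the commit title refers to, is not rendered in this view. Based on the method bodies deleted above, a minimal sketch of what the shared helpers might look like as module-level functions taking the model explicitly (the function names, signatures, and the device handling in test() are assumptions, not confirmed by this diff):

import torch


def fit(model, train_loader, optimizer, loss_fn, device):
    """Run one training epoch; return the mean per-graph loss."""
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        loss = loss_fn(out, data.y.long())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(model, loader, device):
    """Compute graph-classification accuracy over a loader."""
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)  # assumption: some removed per-model test() variants skipped this move
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

A call site would then read fit(model, train_loader, optimizer, loss_fn, device) and test(model, loader, device) in place of the removed model.fit(...) and model.test(...).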