Commit

✨ add MLP
TristanBilot committed Mar 29, 2022
1 parent 1e64e8d commit e6152ae
Showing 3 changed files with 42 additions and 1 deletion.
1 change: 1 addition & 0 deletions phishGNN/models/__init__.py
@@ -6,3 +6,4 @@
 from .graphsage import GraphSAGE
 from .cluster_gcn import ClusterGCN
 from .mem_pool import MemPool
+from .mlp import MLP
39 changes: 39 additions & 0 deletions phishGNN/models/mlp.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+from torch_geometric.nn import global_mean_pool
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        in_channels=None,
+        hidden_channels=None,
+        out_channels=None,
+        pooling_fn=global_mean_pool,
+        device=None,
+        nb_layers=None,
+    ):
+        super(MLP, self).__init__()
+
+        self.pooling_fn = pooling_fn
+        self.hidden_channels = hidden_channels
+        self.device = device
+        self.to(device)
+
+        self.flatten = nn.Flatten()
+        self.lin = nn.Linear(2 * hidden_channels, hidden_channels)
+        self.lin_out = nn.Linear(hidden_channels, out_channels)
+        self.relu = nn.ReLU()
+
+
+    def forward(self, x, edge_index, batch):
+        x = x.float()
+        x = self.flatten(x)
+        x = nn.Linear(x.shape[1], 2 * self.hidden_channels)(x)
+        x = self.relu(x)
+        x = self.lin(x)
+        x = self.relu(x)
+
+        x = self.pooling_fn(x, batch)
+        x = self.lin_out(x)
+        return x
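
For orientation, here is a minimal usage sketch of the new module on a toy torch_geometric batch. The feature width, hidden size, and output size are illustrative assumptions rather than values from the repository, and the snippet assumes it is run from the phishGNN/ directory (like training.py). Note that edge_index is accepted but unused by the MLP itself, and that the first projection layer is created inside forward() from the flattened input width.

import torch
from torch_geometric.data import Batch, Data

from models import MLP  # phishGNN/models/__init__.py now re-exports MLP

# Illustrative sizes only; the real feature width comes from the phishGNN dataset.
num_nodes, num_features = 4, 50
model = MLP(hidden_channels=32, out_channels=2, device=torch.device('cpu'))

# A tiny single-graph batch: 4 nodes connected in a cycle.
data = Data(
    x=torch.randn(num_nodes, num_features),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
)
batch = Batch.from_data_list([data])

# global_mean_pool uses batch.batch to average node embeddings per graph.
out = model(batch.x, batch.edge_index, batch.batch)
print(out.shape)  # torch.Size([1, 2]): one graph-level prediction
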
3 changes: 2 additions & 1 deletion phishGNN/training.py
@@ -11,7 +11,7 @@
 from torch_geometric.loader import DataLoader
 
 from visualization import visualize, plot_embeddings
-from models import GCN_2, GCN_3, GIN, GAT, GraphSAGE, ClusterGCN, MemPool
+from models import GCN_2, GCN_3, GIN, GAT, MLP, GraphSAGE, ClusterGCN, MemPool
 from utils.utils import mean_std_error
 
 
@@ -66,6 +66,7 @@ def train(should_plot_embeddings: bool):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     models = [
+        MLP,
         GAT,
         GIN,
         GCN_2,
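
As a hedged illustration of what this entry implies: the list stores model classes rather than instances, so train() presumably constructs each one further down (not shown in this hunk). The construction below is a sketch for MLP only, with illustrative sizes; the other classes may expect different arguments.

# Sketch only: instantiation happens later in train(); sizes are illustrative.
model_cls = models[0]          # MLP after this commit
model = model_cls(
    hidden_channels=32,
    out_channels=2,
    device=device,             # the device selected above
)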
