Skip to content

Commit

Permalink
🐛 fix the MLP: now using only the root url features
Browse files Browse the repository at this point in the history
  • Loading branch information
TristanBilot committed Mar 30, 2022
1 parent 0e031d7 commit 05f4149
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions phishGNN/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,51 @@ def __init__(

def forward(self, x, edge_index, batch):
    """Forward pass of the MLP baseline.

    Only the first row of the node-feature matrix is used — per the
    original comment this is assumed to be the root url's feature
    vector (TODO confirm the dataset always places the root node first).

    Args:
        x: node feature matrix; only ``x[0]`` is consumed.
        edge_index: graph connectivity (unused by this MLP baseline).
        batch: batch assignment vector, forwarded to ``self.pooling_fn``.

    Returns:
        Output of ``self.lin_out`` after pooling.
    """
    x = x.float()
    # Only use the first feature vector (root url features).
    x = x[0]
    # Bug fix: the original built a fresh nn.Linear on every call, so the
    # input projection had newly-randomized, never-trained weights each
    # forward pass. Create it lazily once and reuse it.
    # NOTE(review): parameters created lazily here are not seen by an
    # optimizer constructed before the first forward call — confirm the
    # training script builds the optimizer after a warm-up pass, or move
    # this projection into __init__ once the input size is known.
    if not hasattr(self, 'root_proj'):
        self.root_proj = nn.Linear(x.shape[0], 2 * self.hidden_channels).to(x.device)
    x = self.root_proj(x)
    x = self.relu(x)
    x = self.lin(x)
    x = self.relu(x)

    x = self.pooling_fn(x, batch)
    x = self.lin_out(x)
    return x


def fit(
    model,
    train_loader,
    optimizer,
    loss_fn,
    device,
    dont_use_graphs: bool=False,
):
    """Train ``model`` for one epoch and return the mean loss per example.

    Args:
        model: the network; called as ``model(x, edge_index, batch)``.
        train_loader: iterable of batches exposing ``.x``, ``.edge_index``,
            ``.batch``, ``.y``, ``.num_graphs`` and a ``.dataset`` attribute.
        optimizer: optimizer stepped once per batch.
        loss_fn: criterion applied to the model output and the target.
        device: device each batch is moved to.
        dont_use_graphs: kept for interface compatibility (unused here).

    Returns:
        Total loss weighted by ``num_graphs``, divided by the dataset size.
    """
    def _train_step(batch):
        # One optimization step; returns the batch loss weighted by its
        # number of graphs so the epoch average is per-example.
        logits = model(batch.x, batch.edge_index, batch.batch)
        # Only the first label is used as the target (root url features).
        loss = loss_fn(logits, batch.y[0].long())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return float(loss) * batch.num_graphs

    model.train()
    running_loss = sum(_train_step(batch.to(device)) for batch in train_loader)
    return running_loss / len(train_loader.dataset)


@torch.no_grad()
def test(model, loader, device):
model.eval()

correct = 0
for data in loader:
data = data.to(device)
out = model(data.x, data.edge_index, data.batch)
# dim -1 instead of 1
pred = out.argmax(dim=-1)
correct += int((pred == data.y).sum())

return correct / len(loader.dataset)

0 comments on commit 05f4149

Please sign in to comment.