Commit

🐛 add some fixs
TristanBilot committed Mar 28, 2022
Parent: 3a9a4f8 · Commit: 58445d8
Showing 3 changed files with 15 additions and 13 deletions.

2 changes: 1 addition & 1 deletion phishGNN/models/mem_pool.py

@@ -88,7 +88,7 @@ def fit(
 
 
     @torch.no_grad()
-    def test(self, loader):
+    def test(self, loader, device):
         self.eval()
         correct = 0
         for data in loader:
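
The hunk only shows the new signature; the rest of the method is cut off here. For context, a minimal sketch of what a device-aware evaluation loop for a PyTorch Geometric model typically looks like (the forward call and accuracy computation below are assumptions, not the repository's code); the same signature change is applied to the module-level test in phishGNN/training.py further down.

    import torch

    @torch.no_grad()
    def evaluate(model, loader, device):
        # Hypothetical helper mirroring the new test(self, loader, device) signature.
        model.eval()
        correct = 0
        for data in loader:
            data = data.to(device)              # the extra `device` argument lets each batch be moved explicitly
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=-1)           # predicted class for each graph in the batch
            correct += int((pred == data.y).sum())
        return correct / len(loader.dataset)    # graph-level accuracy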

2 changes: 1 addition & 1 deletion phishGNN/predict.py

@@ -31,7 +31,7 @@ def predict(url: str, weights_file: str) -> int:
     parser.add_argument('url', type=str, help='the url to predict (phishing/benign)')
     parser.add_argument('pkl_file', type=str, default="GCN_3_global_mean_pool_32.pkl",
         help='the path to the model weights (.pkl)')
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
 
     pred = predict(args.url, args.weights_file)
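
parse_args() aborts with an "unrecognized arguments" error as soon as it sees an option it does not know, while parse_known_args() returns the parsed namespace together with the leftover arguments. A standalone illustration, not taken from predict.py:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('url', type=str)

    # parser.parse_args(['http://example.com', '--verbose'])
    #   would print "error: unrecognized arguments: --verbose" and exit.

    args, unknown = parser.parse_known_args(['http://example.com', '--verbose'])
    print(args.url)   # http://example.com
    print(unknown)    # ['--verbose']  (leftover arguments are returned instead of rejected)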

24 changes: 13 additions & 11 deletions phishGNN/training.py

@@ -38,7 +38,7 @@ def fit(
 
 
 @torch.no_grad()
-def test(model, loader):
+def test(model, loader, device):
     model.eval()
 
     correct = 0

@@ -89,7 +89,7 @@ def train(should_plot_embeddings: bool):
 
     lr = 0.01
     weight_decay = 4e-5
-    epochs = 20
+    epochs = 1
 
     accuracies = defaultdict(lambda: [])
     for (model, pooling, neurons) in itertools.product(
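
The surrounding loop sweeps every combination of model class, pooling function and layer width. A sketch of that grid-search pattern with placeholder option lists (the real lists are defined elsewhere in training.py):

    import itertools

    # Placeholder option lists for illustration only.
    models = ['GCN', 'GIN', 'GraphSAGE']
    poolings = ['global_mean_pool', 'global_max_pool']
    neurons = [32, 64]

    for (model, pooling, hidden) in itertools.product(models, poolings, neurons):
        label = f'{model}_{pooling}_{hidden}'   # one label per configuration (format illustrative)
        print(label)                            # 3 * 2 * 2 = 12 configurations in total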

@@ -117,11 +117,11 @@ def train(should_plot_embeddings: bool):
         else:
             loss = fit(model, train_loader, optimizer, loss_fn, device)
         if hasattr(model, 'test'):
-            train_acc = model.test(train_loader)
-            test_acc = model.test(test_loader)
+            train_acc = model.test(train_loader, device)
+            test_acc = model.test(test_loader, device)
         else:
-            train_acc = test(model, train_loader)
-            test_acc = test(model, test_loader)
+            train_acc = test(model, train_loader, device)
+            test_acc = test(model, test_loader, device)
 
         train_accs.append(train_acc)
         test_accs.append(test_acc)
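
Passing device down to the evaluation helpers matters because PyTorch requires parameters and inputs to live on the same device; if fit moved the model to the GPU, evaluating it on CPU tensors raises a RuntimeError. A minimal, self-contained repro of that constraint (unrelated to the project's models):

    import torch

    lin = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)

    if torch.cuda.is_available():
        lin = lin.to('cuda')
        # lin(x) here would raise:
        # RuntimeError: Expected all tensors to be on the same device ...
        x = x.to('cuda')             # move the input to the same device first

    print(lin(x).shape)              # torch.Size([1, 2]) once both sides agree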

@@ -134,12 +134,14 @@ def train(should_plot_embeddings: bool):
             }
         })
 
-    with open('training.logs', 'w') as logs:
+    out_path = os.path.join("weights", f"{epochs}_epochs")
+    os.makedirs(out_path, exist_ok=True)
+
+    with open(os.path.join(out_path, f"accuracies_{epochs}_epochs.json"), 'w') as logs:
         formatted = json.dumps(accuracies, indent=2)
         logs.write(formatted)
 
-    os.makedirs("weights", exist_ok=True)
-    torch.save(model, f"weights/{label}.pkl")
+    torch.save(model, f"{out_path}/{label}.pkl")
 
     if should_plot_embeddings:
         loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
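
After this change, each run writes its artifacts under weights/<epochs>_epochs/: an accuracies_<epochs>_epochs.json summary and the pickled model(s). A hedged sketch of reading them back (the file names here are examples, and torch.load of a fully pickled model also needs the model class to be importable):

    import json
    import os
    import torch

    out_path = os.path.join('weights', '1_epochs')        # epochs = 1 after this commit

    with open(os.path.join(out_path, 'accuracies_1_epochs.json')) as f:
        accuracies = json.load(f)                         # per-model accuracy records (structure defined in training.py)

    model = torch.load(os.path.join(out_path, 'GCN_3_global_mean_pool_32.pkl'))  # example label
    model.eval()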

@@ -150,8 +152,8 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--plot-embeddings', action="store_false",
+    parser.add_argument('--plot-embeddings', action="store_true",
         help='whether to save the embeddings in a png file during training or not')
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
 
     train(args.plot_embeddings)
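
With action="store_false" the flag defaulted to True and passing --plot-embeddings actually disabled plotting, the opposite of what its name suggests; action="store_true" gives the expected behavior (default False, flag enables). A standalone check:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--plot-embeddings', action='store_true')

    print(p.parse_args([]).plot_embeddings)                      # False (default)
    print(p.parse_args(['--plot-embeddings']).plot_embeddings)   # True

    # With action='store_false' the two results would be inverted:
    # True by default, False when the flag is passed.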
