From 8238302422b0f1cf06bf8910ab998e8926dd872a Mon Sep 17 00:00:00 2001 From: Remigiusz Poltorak Date: Mon, 18 Jan 2021 01:41:01 +0100 Subject: [PATCH] Add accuracy metric to basic model --- mask_detector/model.py | 5 ----- mask_detector/train.py | 18 +++++++++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mask_detector/model.py b/mask_detector/model.py index 83829fe..c15e4ff 100644 --- a/mask_detector/model.py +++ b/mask_detector/model.py @@ -24,14 +24,9 @@ def __init__(self): ) def forward(self, x): - print(x.size()) x = self.convLayers1(x) - print(x.shape) x = self.convLayers2(x) - print(x.shape) x = x.view(-1, 64*8*8) - print(x.shape) x = self.linearLayers(x) - print(x.shape) return x diff --git a/mask_detector/train.py b/mask_detector/train.py index 99ed91e..ede8224 100644 --- a/mask_detector/train.py +++ b/mask_detector/train.py @@ -1,9 +1,12 @@ from argparse import ArgumentParser +import torch from torch.nn.functional import binary_cross_entropy from torch.utils.data import DataLoader, random_split from torch.optim import Adam from pytorch_lightning import LightningModule, Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from sklearn.metrics import accuracy_score from model import BasicCNN from dataset import MaskDataset @@ -36,16 +39,16 @@ def validation_step(self, batch, batch_idx): self.log('valid_loss', loss, on_step=True) - return loss - def test_step(self, batch, batch_idx): x, y = batch out = self.net(x) loss = binary_cross_entropy(out, y) - self.log('test_loss', loss) + _, out = torch.max(out, dim=1) + val_acc = accuracy_score(out.cpu(), y.cpu()) + val_acc = torch.tensor(val_acc) - return loss + return {'test_loss': loss, 'test_acc': val_acc} def configure_optimizers(self): # self.hparams available because we called self.save_hyperparameters() @@ -75,7 +78,12 @@ def cli_main(): # ------------ # training # ------------ - trainer = Trainer(max_epochs=1) + checkpoint_callback = ModelCheckpoint( + verbose=True, + monitor='test_acc', + mode='max' + ) + trainer = Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback) trainer.fit(model, train_loader, val_loader) # ------------