Skip to content
This repository has been archived by the owner on Aug 7, 2023. It is now read-only.

Commit

Permalink
Add accuracy metric to basic model
Browse files Browse the repository at this point in the history
  • Loading branch information
Remigiusz Poltorak committed Jan 18, 2021
1 parent 89f7ad0 commit 8238302
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
5 changes: 0 additions & 5 deletions mask_detector/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,9 @@ def __init__(self):
)

def forward(self, x):
print(x.size())
x = self.convLayers1(x)
print(x.shape)
x = self.convLayers2(x)
print(x.shape)
x = x.view(-1, 64*8*8)
print(x.shape)
x = self.linearLayers(x)
print(x.shape)

return x
18 changes: 13 additions & 5 deletions mask_detector/train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from argparse import ArgumentParser

import torch
from torch.nn.functional import binary_cross_entropy
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import accuracy_score

from model import BasicCNN
from dataset import MaskDataset
Expand Down Expand Up @@ -36,16 +39,16 @@ def validation_step(self, batch, batch_idx):

self.log('valid_loss', loss, on_step=True)

return loss

def test_step(self, batch, batch_idx):
x, y = batch
out = self.net(x)
loss = binary_cross_entropy(out, y)

self.log('test_loss', loss)
_, out = torch.max(out, dim=1)
val_acc = accuracy_score(out.cpu(), y.cpu())
val_acc = torch.tensor(val_acc)

return loss
return {'test_loss': loss, 'test_acc': val_acc}

def configure_optimizers(self):
# self.hparams available because we called self.save_hyperparameters()
Expand Down Expand Up @@ -75,7 +78,12 @@ def cli_main():
# ------------
# training
# ------------
trainer = Trainer(max_epochs=1)
checkpoint_callback = ModelCheckpoint(
verbose=True,
monitor='test_acc',
mode='max'
)
trainer = Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)
trainer.fit(model, train_loader, val_loader)

# ------------
Expand Down

0 comments on commit 8238302

Please sign in to comment.