
Add CLI params support + MobileNetV2 model + small refactoring #9

Merged
merged 10 commits on Mar 8, 2021
Add recall metric & refactor logging in train, validation and test steps
Remigiusz Poltorak committed Jan 31, 2021
commit 22bed223347f4e55fc66485a8c678d7dc2a0b9af
24 changes: 18 additions & 6 deletions mask_detector/mask_classifier.py
@@ -5,7 +5,7 @@
 from torch.utils.data import DataLoader, random_split
 from torch.optim import Adam
 from pytorch_lightning import LightningModule, Trainer, seed_everything
-from pytorch_lightning.metrics import Accuracy
+from pytorch_lightning.metrics import Accuracy, Recall
 from pytorch_lightning.callbacks import ModelCheckpoint
 from sklearn.metrics import accuracy_score
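These metric classes later moved out of pytorch_lightning; a version-tolerant import could look like the sketch below (an assumption for newer environments, not part of this PR):

try:
    from pytorch_lightning.metrics import Accuracy, Recall  # PL <= 1.2, as used in this diff
except ImportError:
    from torchmetrics import Accuracy, Recall  # later PL releases moved metrics to torchmetrics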

@@ -20,7 +20,7 @@ def __init__(self, net, learning_rate=1e-3):
         super().__init__()
         self.net = net
         self.learning_rate = learning_rate
-        self.accuracy = Accuracy()
+        self.recall = Recall()
 
     def forward(self, x):
         return self.net(x)
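The Recall() object created above is a stateful metric: calling it on a batch returns the batch-level value and also accumulates state for an epoch-level result. A minimal sketch of that behavior, assuming the pytorch_lightning 1.1-era metrics API (tensor values are illustrative):

import torch
from pytorch_lightning.metrics import Recall

recall = Recall()                         # binary recall, 0.5 threshold by default
preds = torch.tensor([0.9, 0.2, 0.7])     # sigmoid outputs from the net
target = torch.tensor([1, 0, 0])          # ground-truth labels
batch_recall = recall(preds, target)      # batch value; also updates internal state
epoch_recall = recall.compute()           # recall over all batches seen so far
recall.reset()                            # clear state, e.g. between epochs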
@@ -29,22 +29,34 @@ def training_step(self, batch, batch_idx):
         x, y = batch
         out = self.net(x)
         loss = binary_cross_entropy(out, y)
+        recall = self.recall(out, y)
 
-        return {'loss': loss, 'accuracy': self.accuracy(out, y)}
+        self.log('train_loss', loss, on_step=False, on_epoch=True)
+        self.log('train_recall', recall, on_step=False, on_epoch=True)
+
+        return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         out = self.net(x)
         loss = binary_cross_entropy(out, y)
+        recall = self.recall(out, y)
 
-        return {'loss': loss, 'accuracy': self.accuracy(out, y)}
+        self.log('val_loss', loss, on_step=False, on_epoch=True)
+        self.log('val_recall', recall, on_step=False, on_epoch=True)
+
+        return loss
 
     def test_step(self, batch, batch_idx):
         x, y = batch
         out = self.net(x)
         loss = binary_cross_entropy(out, y)
+        recall = self.recall(out, y)
 
-        return {'loss': loss, 'accuracy': self.accuracy(out, y)}
+        self.log('test_loss', loss, on_step=False, on_epoch=True)
+        self.log('test_recall', recall, on_step=False, on_epoch=True)
+
+        return loss
 
     def configure_optimizers(self):
         # self.hparams available because we called self.save_hyperparameters()
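The three steps above now differ only in their logging prefix; a hypothetical further refactor (a sketch, not part of this commit; _shared_step is an invented helper name) could remove the duplication:

    def _shared_step(self, batch, prefix):
        # prefix is 'train', 'val' or 'test'
        x, y = batch
        out = self.net(x)
        loss = binary_cross_entropy(out, y)
        recall = self.recall(out, y)

        self.log(f'{prefix}_loss', loss, on_step=False, on_epoch=True)
        self.log(f'{prefix}_recall', recall, on_step=False, on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')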
@@ -94,7 +106,7 @@ def cli_main():
     # training
     # ------------
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model, train_loader, val_loader)
+    result = trainer.fit(model, train_loader, val_loader)
 
     # ------------
     # testing
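Once fit() returns, the epoch-level values logged with on_epoch=True are available on the trainer; a minimal sketch of reading them (assuming the PyTorch Lightning 1.x API; the keys match the self.log calls above):

# after trainer.fit(model, train_loader, val_loader)
val_recall = trainer.callback_metrics.get('val_recall')
val_loss = trainer.callback_metrics.get('val_loss')
print(f'val_recall={val_recall}, val_loss={val_loss}')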