diff --git a/.gitignore b/.gitignore
index 3b6ea01..2b0283e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,5 @@
 .pytest_cache/
 *.ipynb
-.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/losses.py b/losses.py
index 43cb446..060073f 100644
--- a/losses.py
+++ b/losses.py
@@ -1,13 +1,12 @@
-"""
-Useful definitions of common image segmentation losses.
+"""Common image segmentation losses.
 """
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+
+from torch.nn import functional as F
 
 
-def bce_loss(true, logits, pos_weight):
+def bce_loss(true, logits, pos_weight=None):
     """Computes the weighted binary cross-entropy loss.
 
     Args:
@@ -21,10 +20,10 @@ def bce_loss(true, logits, pos_weight):
     Returns:
         bce_loss: the weighted binary cross-entropy loss.
     """
-    true_ = true.float()
-    logits_ = logits.float()
     bce_loss = F.binary_cross_entropy_with_logits(
-        logits_, true_, pos_weight=pos_weight,
+        logits.float(),
+        true.float(),
+        pos_weight=pos_weight,
     )
     return bce_loss
 
@@ -43,11 +42,11 @@ def ce_loss(true, logits, weights, ignore=255):
     Returns:
         ce_loss: the weighted multi-class cross-entropy loss.
     """
-    true_ = true.long()
-    logits_ = logits.float()
     ce_loss = F.cross_entropy(
-        logits_, true,
-        ignore_index=ignore, weight=weights
+        logits.float(),
+        true.long(),
+        ignore_index=ignore,
+        weight=weights,
     )
     return ce_loss
 
@@ -81,7 +80,8 @@ def dice_loss(true, logits, eps=1e-7):
     else:
         true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
         true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
-        probas = F.softmax(probas, dim=1)
+        probas = F.softmax(logits, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
     dims = (0,) + tuple(range(2, true.ndimension()))
     intersection = torch.sum(probas * true_1_hot, dims)
     cardinality = torch.sum(probas + true_1_hot, dims)
@@ -119,6 +119,7 @@ def jaccard_loss(true, logits, eps=1e-7):
         true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
         true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
         probas = F.softmax(probas, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
     dims = (0,) + tuple(range(2, true.ndimension()))
     intersection = torch.sum(probas * true_1_hot, dims)
     cardinality = torch.sum(probas + true_1_hot, dims)
@@ -163,12 +164,13 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
         true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
         true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
         probas = F.softmax(probas, dim=1)
+    true_1_hot = true_1_hot.type(logits.type())
     dims = (0,) + tuple(range(2, true.ndimension()))
     intersection = torch.sum(probas * true_1_hot, dims)
     fps = torch.sum(probas * (1 - true_1_hot), dims)
     fns = torch.sum((1 - probas) * true_1_hot, dims)
     num = intersection
-    denom = intersection + alpha*fps + beta*fns
+    denom = intersection + (alpha * fps) + (beta * fns)
     tversky_loss = (num / (denom + eps)).mean()
     return (1 - tversky_loss)
 
diff --git a/metrics.py b/metrics.py
index 8adb760..417e98d 100644
--- a/metrics.py
+++ b/metrics.py
@@ -1,11 +1,11 @@
-"""
-Useful definitions of common ml metrics.
+"""Common image segmentation metrics.
 """
 
 import torch
 
 from utils import nanmean
 
+
 EPS = 1e-10
 
 
diff --git a/models/base.py b/models/base.py
index 3ca4111..a1123de 100644
--- a/models/base.py
+++ b/models/base.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
+from torch.nn import functional as F
 
 
 class BaseModel(nn.Module):
diff --git a/models/mnist.py b/models/mnist.py
index f58ecb8..14a30f6 100644
--- a/models/mnist.py
+++ b/models/mnist.py
@@ -2,8 +2,8 @@
 From https://github.com/pytorch/examples/blob/master/mnist/main.py
 """
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
+from torch.nn import functional as F
 
 from .base import BaseModel
 
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_models.py b/tests/test_models.py
deleted file mode 100644
index 23ecdff..0000000
--- a/tests/test_models.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import torch
-import pytest
-
-from models import MnistConvNet
-
-
-def test_num_params():
-    net = MnistConvNet()
-    print("# of params: {:,}".format(net.num_params))
-