Skip to content
This repository has been archived by the owner on Jan 4, 2023. It is now read-only.

Commit

Permalink
loss fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzakka committed Feb 16, 2019
1 parent 9921b50 commit eb56b4b
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 31 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
.pytest_cache/
*.ipynb
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
30 changes: 16 additions & 14 deletions losses.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""
Useful definitions of common image segmentation losses.
"""Common image segmentation losses.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import functional as F


def bce_loss(true, logits, pos_weight):
def bce_loss(true, logits, pos_weight=None):
"""Computes the weighted binary cross-entropy loss.
Args:
Expand All @@ -21,10 +20,10 @@ def bce_loss(true, logits, pos_weight):
Returns:
bce_loss: the weighted binary cross-entropy loss.
"""
true_ = true.float()
logits_ = logits.float()
bce_loss = F.binary_cross_entropy_with_logits(
logits_, true_, pos_weight=pos_weight,
logits.float(),
true.float(),
pos_weight=pos_weight,
)
return bce_loss

Expand All @@ -43,11 +42,11 @@ def ce_loss(true, logits, weights, ignore=255):
Returns:
ce_loss: the weighted multi-class cross-entropy loss.
"""
true_ = true.long()
logits_ = logits.float()
ce_loss = F.cross_entropy(
logits_, true,
ignore_index=ignore, weight=weights
logits.float(),
true.long(),
ignore_index=ignore,
weight=weights,
)
return ce_loss

Expand Down Expand Up @@ -81,7 +80,8 @@ def dice_loss(true, logits, eps=1e-7):
else:
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(probas, dim=1)
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
cardinality = torch.sum(probas + true_1_hot, dims)
Expand Down Expand Up @@ -119,6 +119,7 @@ def jaccard_loss(true, logits, eps=1e-7):
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(probas, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
cardinality = torch.sum(probas + true_1_hot, dims)
Expand Down Expand Up @@ -163,12 +164,13 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(probas, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
fps = torch.sum(probas * (1 - true_1_hot), dims)
fns = torch.sum((1 - probas) * true_1_hot, dims)
num = intersection
denom = intersection + alpha*fps + beta*fns
denom = intersection + (alpha * fps) + (beta * fns)
tversky_loss = (num / (denom + eps)).mean()
return (1 - tversky_loss)

Expand Down
4 changes: 2 additions & 2 deletions metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
Useful definitions of common ml metrics.
"""Common image segmentation metrics.
"""

import torch

from utils import nanmean


EPS = 1e-10


Expand Down
4 changes: 2 additions & 2 deletions models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn import functional as F


class BaseModel(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions models/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
From https://github.com/pytorch/examples/blob/master/mnist/main.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn import functional as F

from .base import BaseModel

Expand Down
Empty file removed tests/__init__.py
Empty file.
10 changes: 0 additions & 10 deletions tests/test_models.py

This file was deleted.

0 comments on commit eb56b4b

Please sign in to comment.