Skip to content
This repository has been archived by the owner on Jan 4, 2023. It is now read-only.

Commit

Permalink
added tversky loss
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzakka committed Sep 18, 2018
1 parent 877eb61 commit 9921b50
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Useful definitions of common ml losses.
Useful definitions of common image segmentation losses.
"""

import torch
Expand Down Expand Up @@ -86,10 +86,10 @@ def dice_loss(true, logits, eps=1e-7):
intersection = torch.sum(probas * true_1_hot, dims)
cardinality = torch.sum(probas + true_1_hot, dims)
dice_loss = (2. * intersection / (cardinality + eps)).mean()
return (-1 * dice_loss)
return (1 - dice_loss)


def jaccard_loss(true, pred, eps=1e-7):
def jaccard_loss(true, logits, eps=1e-7):
"""Computes the Jaccard loss, a.k.a the IoU loss.
Note that PyTorch optimizers minimize a loss. In this
Expand Down Expand Up @@ -124,7 +124,53 @@ def jaccard_loss(true, pred, eps=1e-7):
cardinality = torch.sum(probas + true_1_hot, dims)
union = cardinality - intersection
jacc_loss = (intersection / (union + eps)).mean()
return (-1 * jacc_loss)
return (1 - jacc_loss)


def tversky_loss(true, logits, alpha, beta, eps=1e-7):
"""Computes the Tversky loss [1].
Args:
true: a tensor of shape [B, H, W] or [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
alpha: controls the penalty for false positives.
beta: controls the penalty for false negatives.
eps: added to the denominator for numerical stability.
Returns:
tversky_loss: the Tversky loss.
Notes:
alpha = beta = 0.5 => dice coeff
alpha = beta = 1 => tanimoto coeff
alpha + beta = 1 => F beta coeff
References:
[1]: https://arxiv.org/abs/1706.05721
"""
num_classes = logits.shape[1]
if num_classes == 1:
true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
true_1_hot_f = true_1_hot[:, 0:1, :, :]
true_1_hot_s = true_1_hot[:, 1:2, :, :]
true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
pos_prob = torch.sigmoid(logits)
neg_prob = 1 - pos_prob
probas = torch.cat([pos_prob, neg_prob], dim=1)
else:
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(probas, dim=1)
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
fps = torch.sum(probas * (1 - true_1_hot), dims)
fns = torch.sum((1 - probas) * true_1_hot, dims)
num = intersection
denom = intersection + alpha*fps + beta*fns
tversky_loss = (num / (denom + eps)).mean()
return (1 - tversky_loss)


def ce_dice(true, pred, log=False, w1=1, w2=1):
Expand Down

0 comments on commit 9921b50

Please sign in to comment.