diff --git a/losses.py b/losses.py index f56ac23..43cb446 100644 --- a/losses.py +++ b/losses.py @@ -1,5 +1,5 @@ """ -Useful definitions of common ml losses. +Useful definitions of common image segmentation losses. """ import torch @@ -86,10 +86,10 @@ def dice_loss(true, logits, eps=1e-7): intersection = torch.sum(probas * true_1_hot, dims) cardinality = torch.sum(probas + true_1_hot, dims) dice_loss = (2. * intersection / (cardinality + eps)).mean() - return (-1 * dice_loss) + return (1 - dice_loss) -def jaccard_loss(true, pred, eps=1e-7): +def jaccard_loss(true, logits, eps=1e-7): """Computes the Jaccard loss, a.k.a the IoU loss. Note that PyTorch optimizers minimize a loss. In this @@ -124,7 +124,53 @@ def jaccard_loss(true, pred, eps=1e-7): cardinality = torch.sum(probas + true_1_hot, dims) union = cardinality - intersection jacc_loss = (intersection / (union + eps)).mean() - return (-1 * jacc_loss) + return (1 - jacc_loss) + + +def tversky_loss(true, logits, alpha, beta, eps=1e-7): + """Computes the Tversky loss [1]. + + Args: + true: a tensor of shape [B, H, W] or [B, 1, H, W]. + logits: a tensor of shape [B, C, H, W]. Corresponds to + the raw output or logits of the model. + alpha: controls the penalty for false positives. + beta: controls the penalty for false negatives. + eps: added to the denominator for numerical stability. + + Returns: + tversky_loss: the Tversky loss. + + Notes: + alpha = beta = 0.5 => dice coeff + alpha = beta = 1 => tanimoto coeff + alpha + beta = 1 => F beta coeff + + References: + [1]: https://arxiv.org/abs/1706.05721 + """ + num_classes = logits.shape[1] + if num_classes == 1: + true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + true_1_hot_f = true_1_hot[:, 0:1, :, :] + true_1_hot_s = true_1_hot[:, 1:2, :, :] + true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) + pos_prob = torch.sigmoid(logits) + neg_prob = 1 - pos_prob + probas = torch.cat([pos_prob, neg_prob], dim=1) + else: + true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + probas = F.softmax(probas, dim=1) + dims = (0,) + tuple(range(2, true.ndimension())) + intersection = torch.sum(probas * true_1_hot, dims) + fps = torch.sum(probas * (1 - true_1_hot), dims) + fns = torch.sum((1 - probas) * true_1_hot, dims) + num = intersection + denom = intersection + alpha*fps + beta*fns + tversky_loss = (num / (denom + eps)).mean() + return (1 - tversky_loss) def ce_dice(true, pred, log=False, w1=1, w2=1):