From 4d6c7f9017ad83e12f485505fbd45ae5073be985 Mon Sep 17 00:00:00 2001
From: Kevin
Date: Sat, 15 Sep 2018 12:58:08 -0700
Subject: [PATCH] multi-class support for losses. Added jaccard loss.

---
 losses.py | 100 ++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 74 insertions(+), 26 deletions(-)

diff --git a/losses.py b/losses.py
index 2abe377..1461f01 100644
--- a/losses.py
+++ b/losses.py
@@ -11,8 +11,8 @@ def bce_loss(true, logits, pos_weight):
     """Computes the weighted binary cross-entropy loss.
 
     Args:
-        true: a tensor of shape [B, H, W] or [B, 1, H, W].
-        logits: a tensor of shape [B, C, H, W]. Corresponds to
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, 1, H, W]. Corresponds to
             the raw output or logits of the model.
         pos_weight: a scalar representing the weight attributed
             to the positive class. This is especially useful for
@@ -21,8 +21,10 @@ def bce_loss(true, logits, pos_weight):
     Returns:
         bce_loss: the weighted binary cross-entropy loss.
     """
+    true_ = true.float()
+    logits_ = logits.float()
     bce_loss = F.binary_cross_entropy_with_logits(
-        logits, true.float(), pos_weight=pos_weight,
+        logits_, true_, pos_weight=pos_weight,
     )
     return bce_loss
 
@@ -31,52 +33,98 @@ def ce_loss(true, logits, weights, ignore=255):
     """Computes the weighted multi-class cross-entropy loss.
 
     Args:
-        true: a tensor of shape [B, H, W] or [B, 1, H, W].
+        true: a tensor of shape [B, 1, H, W].
         logits: a tensor of shape [B, C, H, W]. Corresponds to
             the raw output or logits of the model.
-        weight: a tensor of shape [2,]. The weights attributed
+        weight: a tensor of shape [C,]. The weights attributed
             to each class.
         ignore: the class index to ignore.
 
     Returns:
-        ce_loss: the weighted binary cross-entropy loss.
+        ce_loss: the weighted multi-class cross-entropy loss.
     """
+    true_ = true.squeeze(1).long()
+    logits_ = logits.float()
     ce_loss = F.cross_entropy(
-        logits, true.squeeze(),
+        logits_, true_,
         ignore_index=ignore, weight=weights
     )
     return ce_loss
 
 
-def dice_loss(true, logits, log=False, force_positive=False):
-    """Computes the binary dice loss.
+def dice_loss(true, logits, eps=1e-7):
+    """Computes the Sørensen–Dice loss.
+
+    Note that PyTorch optimizers minimize a loss. In this
+    case, we would like to maximize the dice loss so we
+    return the negated dice loss.
 
     Args:
-        true: a tensor of shape [B, H, W] or [B, 1, H, W].
+        true: a tensor of shape [B, 1, H, W].
         logits: a tensor of shape [B, C, H, W]. Corresponds to
             the raw output or logits of the model.
-        log: whether to return the loss in log space.
-        force_positive: whether to add 1 to the loss to prevent
-            it from becoming negative.
+        eps: added to the denominator for numerical stability.
 
     Returns:
-        dice_loss: the binary dice loss.
+        dice_loss: the Sørensen–Dice loss.
     """
-    eps = 1e-15
-    dice_output = torch.sigmoid(logits)
-    dice_target = (true == 1).float()
-    intersection = (dice_output * dice_target).sum()
-    union = dice_output.sum() + dice_target.sum() + eps
-    dice_loss = 2 * intersection / union
-    if force_positive:
-        return (1 - dice_loss)
-    if log:
-        dice_loss = torch.log(dice_loss)
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze()]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze()]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    cardinality = torch.sum(probas + true_1_hot, dims)
+    dice_loss = (2. * intersection / (cardinality + eps)).mean()
     return (-1 * dice_loss)
 
 
-def jaccard_loss(true, pred):
-    pass
+def jaccard_loss(true, logits, eps=1e-7):
+    """Computes the Jaccard loss, a.k.a the IoU loss.
+
+    Note that PyTorch optimizers minimize a loss. In this
+    case, we would like to maximize the jaccard loss so we
+    return the negated jaccard loss.
+
+    Args:
+        true: a tensor of shape [B, 1, H, W].
+        logits: a tensor of shape [B, C, H, W]. Corresponds to
+            the raw output or logits of the model.
+        eps: added to the denominator for numerical stability.
+
+    Returns:
+        jacc_loss: the Jaccard loss.
+    """
+    num_classes = logits.shape[1]
+    if num_classes == 1:
+        true_1_hot = torch.eye(num_classes + 1)[true.squeeze()]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        true_1_hot_f = true_1_hot[:, 0:1, :, :]
+        true_1_hot_s = true_1_hot[:, 1:2, :, :]
+        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
+        pos_prob = torch.sigmoid(logits)
+        neg_prob = 1 - pos_prob
+        probas = torch.cat([pos_prob, neg_prob], dim=1)
+    else:
+        true_1_hot = torch.eye(num_classes)[true.squeeze()]
+        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
+        probas = F.softmax(logits, dim=1)
+    dims = (0,) + tuple(range(2, true.ndimension()))
+    intersection = torch.sum(probas * true_1_hot, dims)
+    cardinality = torch.sum(probas + true_1_hot, dims)
+    union = cardinality - intersection
+    jacc_loss = (intersection / (union + eps)).mean()
+    return (-1 * jacc_loss)
 
 
 def ce_dice(true, pred, log=False, w1=1, w2=1):