Skip to content
This repository has been archived by the owner on Jan 4, 2023. It is now read-only.

Commit

Permalink
multi-class support for losses. Added jaccard loss.
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzakka committed Sep 15, 2018
1 parent fba50e7 commit 4d6c7f9
Showing 1 changed file with 74 additions and 26 deletions.
100 changes: 74 additions & 26 deletions losses.py
Original file line number Diff line number Diff line change
def bce_loss(true, logits, pos_weight):
    """Computes the weighted binary cross-entropy loss.

    Args:
        true: a tensor of shape [B, 1, H, W] holding the binary
            ground-truth mask.
        logits: a tensor of shape [B, 1, H, W]. Corresponds to
            the raw output or logits of the model.
        pos_weight: a scalar (tensor) representing the weight attributed
            to the positive class. This is especially useful for
            an imbalanced dataset.

    Returns:
        bce_loss: the weighted binary cross-entropy loss.
    """
    # Cast up front so integer masks and non-float logits both work.
    true_ = true.float()
    logits_ = logits.float()
    bce_loss = F.binary_cross_entropy_with_logits(
        logits_, true_, pos_weight=pos_weight,
    )
    return bce_loss

def ce_loss(true, logits, weights, ignore=255):
    """Computes the weighted multi-class cross-entropy loss.

    Args:
        true: a tensor of shape [B, 1, H, W] holding integer class
            indices.
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        weights: a tensor of shape [C,]. The weights attributed
            to each class.
        ignore: the class index to ignore.

    Returns:
        ce_loss: the weighted multi-class cross-entropy loss.
    """
    true_ = true.long()
    logits_ = logits.float()
    # F.cross_entropy expects a [B, H, W] target, so drop the channel
    # dim explicitly. squeeze(1) (not squeeze()) keeps the batch dim
    # when B == 1. Bug fix: the cast result `true_` was previously
    # computed but the raw `true` was passed instead.
    ce_loss = F.cross_entropy(
        logits_, true_.squeeze(1),
        ignore_index=ignore, weight=weights,
    )
    return ce_loss


def dice_loss(true, logits, eps=1e-7):
    """Computes the Sørensen–Dice loss.

    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the dice loss so we
    return the negated dice loss.

    Args:
        true: a tensor of shape [B, 1, H, W] holding integer class
            indices.
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.

    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        # Binary case: build a 2-channel one-hot and reorder it to
        # [positive, negative] so it lines up with the probability
        # stack below. squeeze(1) keeps the batch dim when B == 1.
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        # Bug fix: softmax must be applied to the logits; `probas` was
        # previously referenced here before assignment (NameError).
        probas = F.softmax(logits, dim=1)
    # Reduce over batch and spatial dims, keeping the class dim, then
    # average the per-class dice scores.
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (-1 * dice_loss)


def jaccard_loss(true, logits, eps=1e-7):
    """Computes the Jaccard loss, a.k.a the IoU loss.

    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the jaccard loss so we
    return the negated jaccard loss.

    Args:
        true: a tensor of shape [B, 1, H, W] holding integer class
            indices.
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.

    Returns:
        jacc_loss: the Jaccard loss.
    """
    # Bug fix: the parameter was declared `pred` while the body used
    # `logits` throughout (NameError on every call); renamed to match
    # the body, the docstring, and the sibling dice_loss.
    num_classes = logits.shape[1]
    if num_classes == 1:
        # Binary case: 2-channel one-hot reordered to [positive,
        # negative] to match the probability stack below.
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        # Bug fix: softmax applied to logits, not the undefined probas.
        probas = F.softmax(logits, dim=1)
    # Per-class IoU = |A∩B| / |A∪B|, reduced over batch + spatial dims.
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    jacc_loss = (intersection / (union + eps)).mean()
    return (-1 * jacc_loss)


def ce_dice(true, pred, log=False, w1=1, w2=1):
Expand Down

0 comments on commit 4d6c7f9

Please sign in to comment.