Skip to content
This repository has been archived by the owner on Jan 4, 2023. It is now read-only.

Commit

Permalink
Update losses.py
Browse files Browse the repository at this point in the history
Remove the cast.
  • Loading branch information
mfmezger committed Nov 29, 2019
1 parent ddce5b7 commit ac01eb6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def tversky_loss(true, logits, alpha, beta, eps=1e-7):
neg_prob = 1 - pos_prob
probas = torch.cat([pos_prob, neg_prob], dim=1)
else:
true_1_hot = torch.eye(num_classes)[true.squeeze(1).to(torch.int64)]
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
Expand Down

0 comments on commit ac01eb6

Please sign in to comment.