add label smoothing to loss (#942)
* add label smoothing to loss

* correct comment

* use uniformly distributed smoothing

* add docstring examples
DeNeutoy authored Mar 2, 2018
1 parent f81e27a commit e657353
Showing 2 changed files with 53 additions and 10 deletions.
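
Before the diff itself, a minimal standalone PyTorch sketch (not part of the commit) of the uniformly distributed smoothing the commit message describes: hard one-hot targets become soft targets, with the gold class keeping most of the mass. The numbers mirror the docstring example added below; tensor names and values are illustrative only.

import torch

label_smoothing = 0.2
num_classes = 4
smoothing_value = label_smoothing / num_classes      # 0.05 spread to every class

# Gold class indices for three toy positions, shape (3, 1).
targets = torch.LongTensor([[2], [0], [3]])

# Put (1 - label_smoothing) on the gold class, then add the uniform smoothing
# mass, so the gold class ends up with 1 - 0.2 + 0.05 = 0.85 and the rest 0.05.
one_hot = torch.zeros(3, num_classes).scatter_(1, targets, 1.0 - label_smoothing)
smoothed = one_hot + smoothing_value
print(smoothed[0])   # -> [0.05, 0.05, 0.85, 0.05] for gold class 2, as in the docstring example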
35 changes: 29 additions & 6 deletions allennlp/nn/util.py
@@ -391,7 +391,8 @@ def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
 def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                        targets: torch.LongTensor,
                                        weights: torch.FloatTensor,
-                                       batch_average: bool = True) -> torch.FloatTensor:
+                                       batch_average: bool = True,
+                                       label_smoothing: float = None) -> torch.FloatTensor:
     """
     Computes the cross entropy loss of a sequence, weighted with respect to
     some user provided weights. Note that the weighting here is not the same as
@@ -412,6 +413,11 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
     batch_average : bool, optional, (default = True).
         A bool indicating whether the loss should be averaged across the batch,
         or returned as a vector of losses per batch element.
+    label_smoothing : ``float``, optional (default = None)
+        The amount of label smoothing to apply to the cross-entropy loss.
+        For example, with a label smoothing value of 0.2, a 4 class classification
+        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
+        the correct label.
 
     Returns
     -------
@@ -427,11 +433,20 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
     # shape : (batch * max_len, 1)
     targets_flat = targets.view(-1, 1).long()
 
-    # Contribution to the negative log likelihood only comes from the exact indices
-    # of the targets, as the target distributions are one-hot. Here we use torch.gather
-    # to extract the indices of the num_classes dimension which contribute to the loss.
-    # shape : (batch * sequence_length, 1)
-    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
+    if label_smoothing is not None and label_smoothing > 0.0:
+        num_classes = logits.size(-1)
+        smoothing_value = label_smoothing / num_classes
+        # Fill the correct class indices with (1 - label_smoothing); the uniform smoothing_value is added to every class below.
+        one_hot_targets = zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
+        smoothed_targets = one_hot_targets + smoothing_value
+        negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
+        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
+    else:
+        # Contribution to the negative log likelihood only comes from the exact indices
+        # of the targets, as the target distributions are one-hot. Here we use torch.gather
+        # to extract the indices of the num_classes dimension which contribute to the loss.
+        # shape : (batch * sequence_length, 1)
+        negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
     # shape : (batch, sequence_length)
     negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
     # shape : (batch, sequence_length)
@@ -482,6 +497,14 @@ def ones_like(tensor: torch.Tensor) -> torch.Tensor:
     return tensor.clone().fill_(1)
 
 
+def zeros_like(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Use clone() + fill_() to make sure that a zeros tensor ends up on the right
+    device at runtime.
+    """
+    return tensor.clone().fill_(0)
+
+
 def combine_tensors(combination: str, tensors: List[torch.Tensor]) -> torch.Tensor:
     """
     Combines a list of tensors using element-wise operations and concatenation, specified by a
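
The util.py changes above add both the smoothing branch in sequence_cross_entropy_with_logits and the zeros_like helper it relies on. A condensed standalone sketch of that computation in plain PyTorch (torch.zeros_like stands in for the new module-level helper; the weighting and batch-averaging steps of the real function are omitted, and shapes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 4)              # (batch, sequence_length, num_classes)
targets = torch.randint(0, 4, (2, 5))      # gold class indices
label_smoothing = 0.2
num_classes = logits.size(-1)

log_probs_flat = F.log_softmax(logits.view(-1, num_classes), dim=-1)
targets_flat = targets.view(-1, 1)

smoothing_value = label_smoothing / num_classes
one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
smoothed_targets = one_hot_targets + smoothing_value

# Cross entropy against the smoothed distribution; with label_smoothing = None
# the library code instead gathers only the gold log-probability.
negative_log_likelihood = -(log_probs_flat * smoothed_targets).sum(-1)
print(negative_log_likelihood.view(2, 5).mean())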
28 changes: 24 additions & 4 deletions tests/nn/util_test.py
@@ -422,10 +422,10 @@ def test_sequence_cross_entropy_with_logits_masks_loss_correctly(self):
         tensor[3, :, :] = 0
         weights = (tensor != 0.0)[:, :, 0].long().squeeze(-1)
         tensor2 = tensor.clone()
-        tensor[0, 3:, :] = 2
-        tensor[1, 4:, :] = 13
-        tensor[2, 2:, :] = 234
-        tensor[3, :, :] = 65
+        tensor2[0, 3:, :] = 2
+        tensor2[1, 4:, :] = 13
+        tensor2[2, 2:, :] = 234
+        tensor2[3, :, :] = 65
         targets = torch.LongTensor(numpy.random.randint(0, 3, [5, 7]))
         targets *= weights
 
@@ -437,6 +437,26 @@ def test_sequence_cross_entropy_with_logits_masks_loss_correctly(self):
         loss2 = util.sequence_cross_entropy_with_logits(tensor2, targets, weights)
         assert loss.data.numpy() == loss2.data.numpy()
 
+
+    def test_sequence_cross_entropy_with_logits_smooths_labels_correctly(self):
+        tensor = torch.rand([1, 3, 4])
+        targets = torch.LongTensor(numpy.random.randint(0, 3, [1, 3]))
+
+        tensor = Variable(tensor)
+        targets = Variable(targets)
+        weights = Variable(torch.ones([1, 3]))
+        loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, label_smoothing=0.1)
+
+        correct_loss = 0.0
+        for prediction, label in zip(tensor.squeeze(0), targets.squeeze(0)):
+            prediction = torch.nn.functional.log_softmax(prediction)
+            correct_loss += prediction[label] * 0.9
+            # The smoothing mass (0.1) is spread uniformly over all 4 classes, including the correct one.
+            correct_loss += prediction.sum() * 0.1/4
+        # Average over sequence.
+        correct_loss = - correct_loss / 3
+        numpy.testing.assert_array_almost_equal(loss.data.numpy(), correct_loss.data.numpy())
+
     def test_sequence_cross_entropy_with_logits_averages_batch_correctly(self):
         # test batch average is the same as dividing the batch averaged
         # loss by the number of batches containing any non-padded tokens.
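
Finally, a hedged usage sketch of the new argument. It assumes an AllenNLP checkout that includes this commit, uses the PyTorch 0.3-era Variable wrapper that the tests above use, and assumes the import path mirrors the test file; shapes and values are illustrative.

import torch
from torch.autograd import Variable
from allennlp.nn import util

logits = Variable(torch.rand(2, 5, 4))        # (batch, sequence_length, num_classes)
targets = Variable(torch.zeros(2, 5).long())  # gold class index per timestep
weights = Variable(torch.ones(2, 5))          # 1.0 for real tokens, 0.0 for padding

# Spread 0.2 of the target mass uniformly over the 4 classes.
loss = util.sequence_cross_entropy_with_logits(logits, targets, weights,
                                               label_smoothing=0.2)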
