Skip to content

Commit

Permalink
[SpeechEncoderDecoderModel] Fix bug in reshaping labels (#16748)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi authored Apr 14, 2022
1 parent 048443d commit de8b06f
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def forward(
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

if not return_dict:
if loss is not None:
Expand Down

0 comments on commit de8b06f

Please sign in to comment.