
Commit

attempt to fix attn mask device
younesbelkada committed Aug 9, 2022
1 parent 8cb5ecd commit 98e443c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/transformers/models/opt/modeling_opt.py
@@ -522,7 +522,6 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -534,7 +533,9 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
             combined_attention_mask = (
                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
             )
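For context, a minimal sketch of the failure mode this change guards against: if inputs_embeds lives on a GPU while attention_mask is still on CPU (as can happen with offloaded or multi-device model loading), summing the expanded mask with the causal mask mixes devices and raises a RuntimeError. The _expand_mask below is a simplified stand-in for the helper in modeling_opt.py, not the library's exact implementation, and the shapes are illustrative only.

import torch

# Simplified stand-in for transformers' _expand_mask helper (assumption: the
# real helper has the same shape semantics, [bsz, src_len] -> [bsz, 1, tgt_len, src_len]).
def _expand_mask(mask, dtype, tgt_len=None):
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

device = "cuda" if torch.cuda.is_available() else "cpu"
attention_mask = torch.ones(1, 4)                                  # often left on CPU
inputs_embeds = torch.randn(1, 4, 8, device=device)                # on the model's device
causal_mask = torch.zeros(1, 1, 4, 4, dtype=inputs_embeds.dtype, device=device)

# Before the fix (with device == "cuda"), the next line would raise
# "Expected all tensors to be on the same device":
#     _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=4) + causal_mask

# After the fix: move the expanded mask to the embeddings' device first.
expanded = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=4).to(inputs_embeds.device)
combined = expanded + causal_mask

The .to(inputs_embeds.device) call keeps both operands of the addition on the same device. The diff also drops the "# Copied from ... BartDecoder._prepare_decoder_attention_mask" marker, presumably because the OPT version no longer matches BART's line for line.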
