From 80468251bc3771d53427f77aa2dc9d49a55d2bf0 Mon Sep 17 00:00:00 2001
From: Dan Jones
Date: Thu, 11 Aug 2022 14:45:04 +0100
Subject: [PATCH] Change BartLearnedPositionalEmbedding's forward method
 signature to support Opacus training (#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it

* removing debugging dead code (thanks style checker)

* blackened modeling_bart file

* removing copy inconsistencies via make fix-copies

* changing references to copied signatures in Bart variants

* make fix-copies once more

* using expand over repeat (thanks @michaelbenayoun)

* expand instead of repeat for all model copies

Co-authored-by: Daniel Jones
---
 src/transformers/models/bart/modeling_bart.py | 26 +++++++++++--------
 .../models/mbart/modeling_mbart.py            | 23 +++++++++-------
 src/transformers/models/mvp/modeling_mvp.py   | 22 ++++++++++------
 .../models/plbart/modeling_plbart.py          | 26 +++++++++++--------
 .../models/trocr/modeling_trocr.py            | 19 +++++++++-----
 5 files changed, 70 insertions(+), 46 deletions(-)

diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 8411cc6cefefed..525da6f34b06cf 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -128,12 +128,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -788,17 +790,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1015,10 +1017,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1026,7 +1030,7 @@ def forward(
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1038,7 +1042,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 16ea95bc0aedde..66011fe6a73d0a 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -783,17 +785,18 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1013,10 +1016,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.size()
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1036,7 +1041,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
index d3d239c4cff125..37c1a7d837f7ba 100644
--- a/src/transformers/models/mvp/modeling_mvp.py
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -895,17 +897,19 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1144,10 +1148,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input_ids.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1167,7 +1173,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index d03ddf33ebfa7a..d86decb568192e 100755
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -131,12 +131,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)

@@ -759,17 +761,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -987,10 +989,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -998,7 +1002,7 @@ def forward(
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1010,7 +1014,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index a79e5e901d67c4..e25f73c8b7d3b5 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -87,12 +87,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -626,10 +628,11 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input.shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -640,7 +643,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         if self.config.use_learned_position_embeddings:
-            embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
+            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
         else:
             embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

@@ -651,6 +654,8 @@ def forward(

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

+        input_shape = input.shape
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
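
For reviewers, a minimal standalone sketch of the new contract follows; it is not part of the patch, and the class name, vocabulary size, and embedding width below are placeholders. The patched forward receives the token tensor itself rather than a torch.Size, so per-sample gradient hooks such as the ones Opacus attaches to nn.Embedding layers see a batched module input, and the position ids are broadcast to [bsz, seq_len] with expand, which yields a view over a single arange rather than the per-sample copies that repeat would allocate (the "expand over repeat" point in the commit message).

    import torch
    import torch.nn as nn


    class LearnedPositionalEmbedding(nn.Embedding):
        """Mirrors the patched module: forward takes the input tensor, not its shape."""

        def __init__(self, num_embeddings: int, embedding_dim: int):
            # Bart-family models reserve the first two position slots, hence the offset.
            self.offset = 2
            super().__init__(num_embeddings + self.offset, embedding_dim)

        def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
            bsz, seq_len = input_ids.shape[:2]
            positions = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)  # a view over one arange: one row of position ids per sample, no copy
            return super().forward(positions + self.offset)


    # Toy usage: the position ids now carry the batch dimension, matching input_ids.
    embed = LearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=16)
    token_ids = torch.randint(0, 100, (4, 7))  # [bsz=4, seq_len=7]
    print(embed(token_ids).shape)              # torch.Size([4, 7, 16])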