Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it

* removing debugging dead code (thanks style checker)

* blackened modeling_bart file

* removing copy inconsistencies via make fix-copies

* changing references to copied signatures in Bart variants

* make fix-copies once more

* using expand over repeat (thanks @michaelbenayoun)

* expand instead of repeat for all model copies

Co-authored-by: Daniel Jones <jonesdaniel@microsoft.com>
donebydan and Daniel Jones authored Aug 11, 2022
1 parent 3f0707b commit 8046825
Showing 5 changed files with 70 additions and 46 deletions.
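
The heart of the change: BartLearnedPositionalEmbedding.forward (and its copies in MBart, MVP, PLBart and TrOCR) now receives the token tensor itself rather than its torch.Size, and expands the position ids to the batch dimension. Below is a minimal sketch of the new call convention, assuming a transformers build that includes this commit; the sizes are arbitrary toy values.

import torch
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding

emb = BartLearnedPositionalEmbedding(num_embeddings=32, embedding_dim=8)
input_ids = torch.randint(0, 100, (4, 6))  # (bsz, seq_len); only the shape is read

# Previously the module was called as emb(input_ids.size()) and returned a
# (seq_len, embedding_dim) tensor that was broadcast across the batch.
# It now returns one row of positions per sample.
embed_pos = emb(input_ids)
print(embed_pos.shape)  # torch.Size([4, 6, 8])

Per-sample module inputs and outputs of this kind are what hook-based per-sample gradient engines such as Opacus rely on, which is presumably why the shape-only signature blocked Opacus training.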
26 changes: 15 additions & 11 deletions src/transformers/models/bart/modeling_bart.py
@@ -128,12 +128,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


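As the commit message notes, the batched positions are built with expand rather than repeat: expand returns a broadcast view of the arange result, while repeat would materialise a separate copy for every batch element. A small self-contained sketch of the difference (sizes chosen arbitrarily):

import torch

bsz, seq_len = 4, 6
positions = torch.arange(seq_len)       # shape (seq_len,)

expanded = positions.expand(bsz, -1)    # (4, 6) view, no extra storage allocated
repeated = positions.repeat(bsz, 1)     # (4, 6) copy holding bsz * seq_len elements

assert torch.equal(expanded, repeated)  # identical values, different memory behaviour
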
@@ -788,17 +790,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
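
Note the inputs_embeds branch above: when the caller supplies embeddings directly there are no token ids to hand to embed_positions, so a (bsz, seq_len)-shaped proxy is sliced out of the last feature channel. As far as the positional embedding is concerned, only the shape (and device) of this tensor is used. A tiny illustration with toy sizes:

import torch

inputs_embeds = torch.randn(4, 6, 16)  # (bsz, seq_len, d_model)
input = inputs_embeds[:, :, -1]        # (bsz, seq_len) proxy standing in for token ids
print(input.shape)                     # torch.Size([4, 6])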
@@ -1015,18 +1017,20 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1038,7 +1042,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
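
The decoder also forwards past_key_values_length so that incremental generation keeps indexing positions correctly: when only the newest token is fed in, its position starts after the cached ones. A quick sketch of the arange-plus-expand logic from the new forward, with made-up sizes:

import torch

bsz, seq_len, past_key_values_length = 2, 1, 5  # one new token, five cached positions
positions = torch.arange(
    past_key_values_length, past_key_values_length + seq_len, dtype=torch.long
).expand(bsz, -1)
print(positions)  # tensor([[5], [5]])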
23 changes: 14 additions & 9 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -783,17 +785,18 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1013,10 +1016,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.size()
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1036,7 +1041,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
22 changes: 14 additions & 8 deletions src/transformers/models/mvp/modeling_mvp.py
@@ -134,12 +134,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -895,17 +897,19 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1144,10 +1148,12 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input_ids.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -1167,7 +1173,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
26 changes: 15 additions & 11 deletions src/transformers/models/plbart/modeling_plbart.py
@@ -131,12 +131,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -759,17 +761,17 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)

         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -987,18 +989,20 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1010,7 +1014,7 @@ def forward(
             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)

         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
19 changes: 12 additions & 7 deletions src/transformers/models/trocr/modeling_trocr.py
@@ -87,12 +87,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)


@@ -626,10 +628,11 @@ def forward(
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input.shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

@@ -640,7 +643,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         if self.config.use_learned_position_embeddings:
-            embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
+            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
         else:
             embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

@@ -651,6 +654,8 @@

         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

+        input_shape = input.shape
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
