Add head_mask and decoder_head_mask to FSMT (huggingface#9819)
* Add {decoder_,}head_mask to fsmt_modeling.py

* Enable test_headmasking and make some changes to the docs

* Remove test_head_masking flag from fsmt test file

Remove test_head_masking flag from test_modeling_fsmt.py,
since test_head_masking is True by default (storing it explicitly is redundant).

* Merge master and remove test_head_masking = True

* Rebase necessary due to an update of jaxlib

* Remove test_head_masking=True in tests/test_modeling_fsmt.py
as it is redundant.
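In short, this commit lets callers nullify individual attention heads in the FSMT encoder and decoder. A minimal usage sketch (the checkpoint name and the choice of head to silence are illustrative, not part of this commit):

```python
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

name = "facebook/wmt19-en-de"  # any FSMT checkpoint exposes the same arguments
tokenizer = FSMTTokenizer.from_pretrained(name)
model = FSMTForConditionalGeneration.from_pretrained(name)

inputs = tokenizer("Machine learning is great, isn't it?", return_tensors="pt")

# Shape (num_layers, num_heads): 1.0 keeps a head, 0.0 nullifies it.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # e.g. silence head 0 of encoder layer 0
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
print(outputs.logits.shape)
```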
stancld committed Feb 1, 2021
1 parent 74f16b8 commit 0c6c0af
Showing 2 changed files with 109 additions and 17 deletions.
117 changes: 101 additions & 16 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -240,6 +240,17 @@
also be used by default. If you want to change padding behavior, you should read
:func:`modeling_fsmt._prepare_fsmt_decoder_inputs` and modify it. See diagram 1 in the paper for more info on
the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
@@ -282,7 +293,11 @@ def triu_onnx(x, diagonal=0):


def _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
config,
input_ids,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=torch.float32,
):
"""
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
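For orientation, the causal half of those masks follows the standard upper-triangular pattern. A rough standalone sketch, not the function's exact code (which routes through `triu_onnx` above for ONNX export compatibility):

```python
import torch

tgt_len = 4
# Future positions get -inf so that softmax assigns them zero probability.
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
# Row i (a query position) may now attend only to columns 0..i.
```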
@@ -377,21 +392,27 @@ def __init__(self, config: FSMTConfig):
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(self, x, encoder_padding_mask, output_attentions=False):
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
x (:obj:`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (:obj:`torch.ByteTensor`): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
positions marked ``1`` are excluded (masked out) from attention, while
positions marked ``0`` are included
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
query=x,
key=x,
key_padding_mask=encoder_padding_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
@@ -432,21 +453,32 @@ def __init__(self, config: FSMTConfig, embed_tokens):
) # type: List[EncoderLayer]

def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True
self,
input_ids,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
input_ids (:obj:`torch.LongTensor`): tokens in the source language of shape
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens
attention_mask (:obj:`torch.LongTensor`): indicating which indices are padding tokens
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
Returns:
BaseModelOutput or Tuple comprised of:
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len,
batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
- **x** (:obj:`torch.Tensor`): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (:obj:`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape
`(src_len, batch, embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (:obj:`Tuple(torch.FloatTensor)`): Attention weights for each layer.
During training, these may not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
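The inversion the comment refers to flips the user-facing convention (1 = attend to this token) into the internal padding convention (True = ignore this token); a sketch of the idea:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # user-facing: 1 = real token, 0 = padding
encoder_padding_mask = attention_mask.eq(0)        # internal: True marks padding to ignore
```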
@@ -463,7 +495,12 @@ def forward(

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
x = x.transpose(0, 1) # T x B x C -> B x T x C
encoder_states += (x,)
@@ -473,7 +510,12 @@
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
x, attn = encoder_layer(
x,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

if output_attentions:
all_attentions = all_attentions + (attn,)
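The loop above shows the dispatch pattern this commit adds in both the encoder and decoder: validate the mask's layer dimension once, then hand each layer its own `(num_heads,)` slice. A condensed sketch of that pattern:

```python
import torch
from torch import nn

num_layers, num_heads = 6, 16
layers = nn.ModuleList(nn.Identity() for _ in range(num_layers))  # stand-ins for real layers
head_mask = torch.ones(num_layers, num_heads)

assert head_mask.size(0) == len(layers)  # one mask row per layer
for idx, layer in enumerate(layers):
    layer_head_mask = head_mask[idx]  # shape (num_heads,), passed into the layer's attention
```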
@@ -522,6 +564,8 @@ def forward(
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
layer_head_mask=None,
encoder_layer_head_mask=None,
decoder_padding_mask=None,
output_attentions=False,
):
@@ -537,6 +581,7 @@ def forward(
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -551,6 +596,7 @@ def forward(
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
layer_head_mask=encoder_layer_head_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -611,6 +657,8 @@ def forward(
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
@@ -622,12 +670,24 @@
EMNLP 2019).
Args:
input_ids (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch, tgt_len)`):
previous decoder outputs for teacher forcing
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
past_key_values (dict or None): dictionary used for storing state during generation
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder, to avoid performing
cross-attention on masked heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
Returns:
BaseModelOutputWithPast or tuple:
@@ -662,6 +722,12 @@ def forward(
all_self_attns = () if output_attentions else None
all_cross_attns = () if output_attentions else None
next_decoder_cache = []

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -681,6 +747,8 @@
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
output_attentions=output_attentions,
)

@@ -761,6 +829,7 @@ def forward(
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
@@ -830,6 +899,13 @@

attn_weights = F.softmax(attn_weights, dim=-1)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

if output_attentions:
# make sure that attn_weights are included in graph
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
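The head-masking multiplication a few lines above is a pure broadcast: the `(num_heads,)` mask is reshaped to `(1, num_heads, 1, 1)` so one scalar per head scales every batch element and every query/key position at once. A self-contained check of that arithmetic:

```python
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # nullify head 1

masked = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, num_heads, tgt_len, src_len)
attn_weights = masked.view(bsz * num_heads, tgt_len, src_len)

# Every attention weight belonging to head 1 is now exactly zero.
assert attn_weights.view(bsz, num_heads, tgt_len, src_len)[:, 1].eq(0).all()
```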
@@ -923,6 +999,8 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
use_cache=None,
@@ -958,6 +1036,7 @@ def forward(
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -977,6 +1056,8 @@
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -1052,6 +1133,8 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
labels=None,
@@ -1080,6 +1163,8 @@ def forward(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
9 changes: 8 additions & 1 deletion tests/test_modeling_fsmt.py
@@ -111,12 +111,20 @@ def prepare_fsmt_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
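Defaulting both masks to all-ones keeps the shared tests behavior-neutral: multiplying attention weights by 1.0 is the identity, so only the generic head-masking test exercises zeroed heads. A quick sanity sketch of that property:

```python
import torch

attn = torch.softmax(torch.randn(3, 4, 5, 5), dim=-1)  # (bsz, heads, tgt_len, src_len)
ones_mask = torch.ones(4)                               # one entry per head
assert torch.equal(ones_mask.view(1, -1, 1, 1) * attn, attn)  # all-ones mask is a no-op
```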


@@ -126,7 +134,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False

def setUp(self):
