[Flax BERT/Roberta] few small fixes (huggingface#11558)
* small fixes

* style
patil-suraj authored May 3, 2021
1 parent a5d2967 commit 623281a
Showing 2 changed files with 11 additions and 16 deletions.
15 changes: 6 additions & 9 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -25,7 +25,6 @@
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
-from jax.random import PRNGKey

from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
@@ -92,9 +91,9 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
-This model is also a Flax Linen `flax.nn.Module
-<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-Module and refer to the Flax documentation for all matter related to general usage and behavior.
+This model is also a Flax Linen `flax.linen.Module
+<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
@@ -106,8 +105,8 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
-configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-weights.
+configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
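
For context on the docstring fix above: weights are loaded through FlaxPreTrainedModel.from_pretrained, which the Flax BERT classes inherit, rather than PreTrainedModel.from_pretrained. A minimal usage sketch (the checkpoint name is only an illustrative choice, not part of this commit):

from transformers import BertConfig, FlaxBertModel

# from_pretrained loads both the configuration and the pretrained weights.
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Instantiating from a config alone gives randomly initialized weights,
# which is the distinction the docstring is drawing.
config = BertConfig()
model_random_init = FlaxBertModel(config)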
@@ -173,15 +172,13 @@ def setup(self):
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
@@ -571,7 +568,7 @@ def __call__(
token_type_ids=None,
position_ids=None,
params: dict = None,
-dropout_rng: PRNGKey = None,
+dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
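
The import and annotation changes above only touch the type hint: dropping the unused from jax.random import PRNGKey and spelling the annotation as jax.random.PRNGKey does not change how the dropout RNG is supplied at call time. A minimal sketch of a training-mode call (checkpoint name and toy inputs are illustrative only):

import jax
import jax.numpy as jnp
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")
input_ids = jnp.ones((1, 8), dtype="i4")  # toy batch of token ids

# With train=True dropout is active, so the call needs its own PRNG key.
dropout_rng = jax.random.PRNGKey(0)
outputs = model(input_ids, dropout_rng=dropout_rng, train=True)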
12 changes: 5 additions & 7 deletions src/transformers/models/roberta/modeling_flax_roberta.py
@@ -59,9 +59,9 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
-This model is also a Flax Linen `flax.nn.Module
-<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-Module and refer to the Flax documentation for all matter related to general usage and behavior.
+This model is also a Flax Linen `flax.linen.Module
+<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
@@ -73,8 +73,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
Parameters:
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
-configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-weights.
+configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+model weights.
"""

ROBERTA_INPUTS_DOCSTRING = r"""
@@ -140,15 +140,13 @@ def setup(self):
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
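
In both files the dropped lines are dead code: the unpacked batch_size and sequence_length were only referenced by the commented-out reshape, and that reshape is unnecessary because each embedding lookup already returns an array of shape (batch_size, sequence_length, hidden_size), so their sum keeps that shape. A standalone sketch with toy sizes (names and numbers here are illustrative, not taken from the modeling code):

import jax
import jax.numpy as jnp

batch_size, sequence_length, hidden_size = 4, 16, 8
vocab_size, type_vocab_size, max_positions = 32, 2, 16

rng = jax.random.PRNGKey(0)
word_table = jax.random.normal(rng, (vocab_size, hidden_size))
position_table = jax.random.normal(rng, (max_positions, hidden_size))
token_type_table = jax.random.normal(rng, (type_vocab_size, hidden_size))

input_ids = jnp.zeros((batch_size, sequence_length), dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length), (batch_size, sequence_length))
token_type_ids = jnp.zeros((batch_size, sequence_length), dtype="i4")

# Each lookup maps (batch, seq) integer ids to (batch, seq, hidden) embeddings,
# so summing them needs no reshape.
hidden_states = (
    word_table[input_ids]
    + token_type_table[token_type_ids]
    + position_table[position_ids]
)
print(hidden_states.shape)  # (4, 16, 8)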
