
[Flax] Add general conversion script #10809

Merged

Conversation

@patrickvonplaten (Contributor) commented Mar 19, 2021

What does this PR do?

This PR changes the weight structure of FlaxBertModel so that it corresponds 1-to-1 to PyTorch's BertModel. This means that some weights had to be renamed (e.g. "layer_norm" -> "LayerNorm", since PyTorch uses "LayerNorm") and some new flax.linen.Module classes, such as FlaxBertSelfOutput, had to be created.

As can be seen, the PT => Flax conversion function is now kept fully general and can be applied to all models, so we can delete all model-specific conversion logic.
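For reference, the generic conversion essentially boils down to a key-renaming loop over the PyTorch state dict. Below is a minimal sketch (illustrative only; the function name and the exact set of rename rules in the PR may differ):

import jax.numpy as jnp
from flax.core.frozen_dict import freeze
from flax.traverse_util import unflatten_dict


def convert_pt_state_dict_to_flax(pt_state_dict):
    """Sketch of a model-agnostic PyTorch -> Flax weight conversion."""
    flax_state_dict = {}
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))
        tensor = jnp.asarray(pt_tensor.numpy())  # assumes CPU torch tensors

        # Only generic rename rules are needed because module names match 1-to-1.
        if pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)  # LayerNorm weight
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)  # LayerNorm bias
        elif pt_tuple_key[-1] == "weight" and tensor.ndim == 2:
            # torch.nn.Linear stores (out, in); flax.linen.Dense expects (in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            tensor = tensor.T
        # (further generic rules, e.g. for embeddings and LayerNorm weights, omitted)

        flax_state_dict[pt_tuple_key] = tensor

    return freeze(unflatten_dict(flax_state_dict))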

The PR does have drawbacks, however:

Flax's official SelfAttention module cannot be used anymore, since it doesn't give us enough flexibility to convert PyTorch weights to Flax weights without a model-specific conversion function. FlaxBERT's new attention modules correspond fully to PyTorch BERT's attention modules and are IMO still kept quite short by relying on Flax's dot_product_attention function. Another drawback is that for auto-regressive Transformer models we will have to manually add all the code for cached / auto-regressive attention to the attention module (which we do for PyTorch anyway), instead of being able to reuse the existing code in flax.linen.SelfAttention -> see here: https://github.com/google/flax/blob/e31063da71bd7a4df137b000df6a48b0cea35a2b/flax/linen/attention.py#L202.

All in all, rewriting parts of flax.linen.SelfAttention is still the right choice here, because it allows us to have a much cleaner conversion function with very little downside IMO (slightly higher maintenance, because we need to copy-paste a bit more code).

@LysandreJik @sgugger - could you check whether you more or less agree with my solution here (below I left some comments to showcase the trade-offs a bit better)? I'll then clean up the code & upload the new weight structure :-)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten patrickvonplaten changed the title [Flax] Add general conversion script [WIP][Flax] Add general conversion script Mar 19, 2021
)

# Need to change some parameter names to match Flax names so that we don't have to fork any layer
for pt_key, pt_tensor in pt_state_dict.items():
patrickvonplaten (Contributor, Author):

The conversion function can be kept short & concise by forcing FlaxBert to have the exact same module names and architecture as PyTorch's BERT.

@@ -121,11 +122,6 @@ def params(self, params: Union[Dict, FrozenDict]):
)
self._params = freeze(params)

@staticmethod
patrickvonplaten (Contributor, Author):

We can delete all model-specific conversion methods now :-)

elif pt_tuple_key[-1] == "beta":
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

# THIS AND MORE WOULD BE NEEDED IF ATTENTION FN IS USED
patrickvonplaten (Contributor, Author):

This and much more code would have to be added if we decided to stick with flax.linen.SelfAttention.

config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation

def setup(self):
self.self_attention = nn.attention.SelfAttention(
patrickvonplaten (Contributor, Author):

flax.linen.SelfAttention is removed, and the query, key and value weights are created in the same way as we do for PyTorch.
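For illustration, the rewritten module's setup could look roughly like this (a minimal sketch; the exact layer arguments and initializers in the PR may differ):

import flax.linen as nn
import jax.numpy as jnp

from transformers import BertConfig


class FlaxBertSelfAttention(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # One nn.Dense per projection, named exactly like PyTorch's
        # BertSelfAttention submodules ("query", "key", "value"), so the
        # generic conversion maps the weights over without special casing.
        self.query = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.key = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.value = nn.Dense(self.config.hidden_size, dtype=self.dtype)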

if not deterministic and self.dropout_rate > 0.0:
dropout_rng = self.make_rng("dropout")

attn_output = dot_product_attention(
patrickvonplaten (Contributor, Author):

Flax's flax.linen.dot_product_attention function can still save us quite a bit of code.
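Continuing the sketch from above, the __call__ could rely on flax.linen.dot_product_attention roughly as follows (argument names such as hidden_states and attention_bias are illustrative, not necessarily those used in the PR):

    def __call__(self, hidden_states, attention_bias=None, deterministic=True):
        num_heads = self.config.num_attention_heads
        head_dim = self.config.hidden_size // num_heads

        # project and split into heads: (batch, seq_len, num_heads, head_dim)
        query = self.query(hidden_states).reshape(hidden_states.shape[:2] + (num_heads, head_dim))
        key = self.key(hidden_states).reshape(hidden_states.shape[:2] + (num_heads, head_dim))
        value = self.value(hidden_states).reshape(hidden_states.shape[:2] + (num_heads, head_dim))

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        # dot_product_attention handles scaling, softmax and attention dropout
        attn_output = nn.dot_product_attention(
            query,
            key,
            value,
            bias=attention_bias,  # additive mask, e.g. large negative values at padded positions
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            deterministic=deterministic,
            dtype=self.dtype,
        )

        # merge heads back: (batch, seq_len, hidden_size)
        return attn_output.reshape(attn_output.shape[:2] + (-1,))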


return jax_state

# THIS AND MORE WOULD BE NEEDED IF WE KEEP nn.self_attention
patrickvonplaten (Contributor, Author):

Keeping flax.linen.SelfAttention would force us to have all these model-specific conversion look-up tables, which I don't think is worth it.
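For context, here is a hypothetical illustration of the kind of model-specific table that would otherwise be needed (the names, shapes and dictionary layout are purely illustrative, not taken from the PR):

# flax.linen.SelfAttention stores fused per-head kernels of shape
# (hidden_size, num_heads, head_dim), so every model would need its own
# rename *and* reshape rules for the attention weights.
BERT_ATTENTION_CONVERSION = {
    ("attention", "self", "query", "weight"): (
        ("attention", "self_attention", "query", "kernel"),
        lambda w, cfg: w.T.reshape(cfg.hidden_size, cfg.num_attention_heads, -1),
    ),
    ("attention", "self", "key", "weight"): (
        ("attention", "self_attention", "key", "kernel"),
        lambda w, cfg: w.T.reshape(cfg.hidden_size, cfg.num_attention_heads, -1),
    ),
    # ... and similarly for value, the biases and the output projection
}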

@patrickvonplaten patrickvonplaten changed the title [WIP][Flax] Add general conversion script [Flax] Add general conversion script Mar 23, 2021
@patrickvonplaten patrickvonplaten changed the title [Flax] Add general conversion script [WIP][Flax] Add general conversion script Mar 23, 2021
@sgugger (Collaborator) left a comment:

Great work! My only concern is to make sure we don't lose any performance by not using flax.linen.SelfAttention. If we are just using the same code as its implementation, there is no reason we would, but it's good to double-check.
Otherwise, I agree it's better to re-implement it than to have custom weight loading logic.

(Resolved review comments on src/transformers/modeling_flax_pytorch_utils.py, src/transformers/models/bert/modeling_flax_bert.py and src/transformers/models/roberta/modeling_flax_roberta.py)
@LysandreJik (Member) left a comment:

Great, this is very clean! If this has no performance impact, this is a very welcome change.


class FlaxBertSelfOutput(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
LysandreJik (Member):

Is this handled automatically when using FP16/bfloat16?

(Resolved review comment on src/transformers/models/roberta/modeling_flax_roberta.py)
@patrickvonplaten (Contributor, Author) replied:

> Great work! My only concern is to make sure we don't lose any performance by not using flax.linen.SelfAttention. If we are just using the same code as its implementation, there is no reason we would, but it's good to double-check.
> Otherwise, I agree it's better to re-implement it than to have custom weight loading logic.

Great! Yeah, I'll talk with @avital about this next week (hopefully) :-)

@patrickvonplaten patrickvonplaten merged commit 8780caa into huggingface:master Mar 30, 2021
@patrickvonplaten patrickvonplaten deleted the add_flax_conversion branch March 30, 2021 09:14
@patrickvonplaten patrickvonplaten changed the title [WIP][Flax] Add general conversion script [Flax] Add general conversion script Mar 31, 2021
Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
* save intermediate

* finish first version

* delete some more

* improve import

* fix roberta

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* small corrections

* apply all comments

* fix deterministic

* make fix-copies

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>