[Flax] Add general conversion script #10809
Conversation
)

# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for pt_key, pt_tensor in pt_state_dict.items():
The conversion function can be kept short & concise by forcing FlaxBert to have the exact same model names and architecture as PyTorch's BERT
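To illustrate the idea, here is a minimal sketch of what such a general renaming loop could look like (the rules and helper name are assumptions for illustration, not the exact code in modeling_flax_pytorch_utils.py):

```python
import jax.numpy as jnp


def convert_pt_state_dict_to_flax_sketch(pt_state_dict):
    """Sketch: rename PyTorch keys into Flax-style keys instead of forking layers."""
    jax_state = {}
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        # LayerNorm parameters are sometimes stored as "gamma"/"beta" in PyTorch checkpoints
        if pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        # Dense kernels are (out, in) in PyTorch but (in, out) in Flax
        if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 2:
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T

        # Flattened tuple keys for brevity; the real function would unflatten them
        # into the nested Flax param dict (further rules omitted here).
        jax_state[pt_tuple_key] = jnp.asarray(pt_tensor.numpy())
    return jax_state
```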
@@ -121,11 +122,6 @@ def params(self, params: Union[Dict, FrozenDict]):
        )
        self._params = freeze(params)

    @staticmethod
We can delete all model-specific conversion methods now :-)
elif pt_tuple_key[-1] == "beta":
    pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

# THIS AND MORE WOULD BE NEEDED IF ATTENTION FN IS USED
This and much more code would have to be added if we decided to stick with flax.linen.SelfAttention.
config: BertConfig
dtype: jnp.dtype = jnp.float32  # the dtype of the computation

def setup(self):
    self.self_attention = nn.attention.SelfAttention(
flax.linen.SelfAttention is removed and the query, key and value weights are created the same way we do it for PyTorch.
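As a rough sketch of that pattern (module and attribute names are illustrative assumptions, not the PR's final code), the projections become plain nn.Dense layers whose names mirror PyTorch's BertSelfAttention:

```python
import flax.linen as nn
import jax.numpy as jnp
from transformers import BertConfig


class FlaxBertSelfAttentionSketch(nn.Module):
    """Sketch: explicit query/key/value projections so parameter names match PyTorch 1-to-1."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Same attribute names as PyTorch's BertSelfAttention ("query", "key", "value"),
        # so the general conversion needs no model-specific renaming here.
        self.query = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.key = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.value = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        # __call__ (head splitting + dot_product_attention) omitted for brevity
```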
if not deterministic and self.dropout_rate > 0.0:
    dropout_rng = self.make_rng("dropout")

attn_output = dot_product_attention(
Flax's flax.linen.dot_product_attention function can still save us quite some code.
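For reference, a small self-contained example of calling the functional helper (shapes and keyword arguments as I understand flax.linen.dot_product_attention; a sketch, not the PR's code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# flax.linen.dot_product_attention expects (batch, seq_len, num_heads, head_dim)
# for query/key/value.
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)

query = jax.random.normal(q_rng, (batch, seq_len, num_heads, head_dim))
key = jax.random.normal(k_rng, (batch, seq_len, num_heads, head_dim))
value = jax.random.normal(v_rng, (batch, seq_len, num_heads, head_dim))

# The helper handles scaling, softmax and the weighted sum over values.
attn_output = nn.dot_product_attention(query, key, value, deterministic=True)

# Merge the heads back into the hidden dimension: (batch, seq_len, num_heads * head_dim)
attn_output = attn_output.reshape(batch, seq_len, -1)
print(attn_output.shape)  # (2, 8, 64)
```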
return jax_state

# THIS AND MORE WOULD BE NEEDED IF WE KEEP nn.self_attention
Keeping flax.linen.SelfAttention would force us to have all these model-specific conversion look-up tables, which I don't think is worth it.
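For context, a purely hypothetical example of the kind of per-model look-up table this would force on every architecture (names invented for illustration):

```python
# Hypothetical mapping from flax.linen.SelfAttention's nested parameter layout
# back to PyTorch BERT's flat names -- every model would need its own table like this,
# plus reshaping rules for the split-head kernels.
BERT_SPECIFIC_KEY_MAP = {
    ("attention", "self_attention", "query", "kernel"): ("attention", "self", "query", "weight"),
    ("attention", "self_attention", "key", "kernel"): ("attention", "self", "key", "weight"),
    ("attention", "self_attention", "value", "kernel"): ("attention", "self", "value", "weight"),
    ("attention", "self_attention", "out", "kernel"): ("attention", "output", "dense", "weight"),
    # ... one entry (or rule) per mismatching parameter
}
```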
Great work! My only concern is to make sure we don't lose any performance by not using nn.linen.SelfAttention. If we are just using the same code as its implementation, there is no reason for concern, but it's good to double-check.
Otherwise, I agree it's better to re-implement it than to have custom weight loading logic.
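One way to double-check could be a quick micro-benchmark of the old and new attention modules under jax.jit (a hedged sketch with an assumed helper, not something from the PR):

```python
import time

import jax


def benchmark(apply_fn, params, inputs, n_iters=50):
    """Rough timing helper: jit-compile, warm up once, then average wall-clock time."""
    jitted = jax.jit(apply_fn)
    jax.block_until_ready(jitted(params, inputs))  # first call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        out = jitted(params, inputs)
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / n_iters


# e.g. compare benchmark(old_module.apply, params, hidden_states)
# against benchmark(new_module.apply, params, hidden_states) on identical inputs
```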
Great, this is very clean! If this has no performance impact, this is a very welcome change.
class FlaxBertSelfOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
Is this handled automatically when using FP16/bfloat16?
Great! Yeah, I'll talk with @avital about this next week (hopefully) :-)
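For what it's worth, my understanding of the Flax convention (an assumption, not something confirmed in this thread) is that the dtype attribute only takes effect where it is explicitly passed to submodules; a tiny sketch:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    """Sketch: computation runs in self.dtype, parameters stay float32 by default."""

    features: int = 8
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, x):
        # dtype has to be threaded into each submodule explicitly
        return nn.Dense(self.features, dtype=self.dtype)(x)


block = TinyBlock(dtype=jnp.bfloat16)
params = block.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
out = block.apply(params, jnp.ones((1, 4)))
print(out.dtype)  # bfloat16 activations, while params remain float32
```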
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
…aten/transformers into add_flax_conversion
* save intermediate
* finish first version
* delete some more
* improve import
* fix roberta
* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* small corrections
* apply all comments
* fix deterministic
* make fix-copies
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
What does this PR do?
This PR changes the weight architecture of FlaxBertModel so that it corresponds 1-to-1 to PyTorch's version of BertModel. This means that some weights had to be renamed (e.g. "layer_norm" -> "LayerNorm", since PyTorch uses "LayerNorm") and that some new flax.linen.Module classes, such as FlaxBertSelfOutput, had to be created. As can be seen, the PT => Flax conversion function is now kept very general and can be applied to all models, so that we can fully delete any model-specific conversion logic.

The PR has one drawback however: Flax's official SelfAttention module cannot be used anymore, since it doesn't give us enough flexibility to convert PyTorch weights to Flax weights without a model-specific conversion function. FlaxBERT's new attention modules fully correspond to PyTorch BERT's attention modules and are IMO still kept quite short by relying on Flax's dot_product_attention function. Another drawback is that for auto-regressive Transformer models we will have to manually add all the code for cached / auto-regressive attention to the attention module (which we do for PyTorch anyway) instead of being able to reuse the existing code in nn.linen.SelfAttention -> see here: https://github.com/google/flax/blob/e31063da71bd7a4df137b000df6a48b0cea35a2b/flax/linen/attention.py#L202.

All in all, rewriting parts of flax.linen.SelfAttention is the right choice here though, because it allows us to have a much cleaner conversion function with very little downside IMO (slightly higher maintenance because we need to copy-paste a bit more code).

@LysandreJik @sgugger - could you check if you agree more or less with my solution here (below I left some comments to showcase the trade-offs a bit better)? I'll then clean the code & upload the new weight structure :-)
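If I understand the end state correctly, loading a PyTorch checkpoint into the Flax model then reduces to the usual from_pt path (a usage sketch assuming the standard from_pretrained API, not code from this PR):

```python
from transformers import FlaxBertModel

# Convert the PyTorch state dict on the fly using only the general PT -> Flax
# renaming rules -- no model-specific conversion method involved.
model = FlaxBertModel.from_pretrained("bert-base-uncased", from_pt=True)
```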
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.