Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed May 17, 2021
1 parent 936b571 commit c73e353
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_flax_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def __init__(
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.ones_like(input_ids)
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)

Expand Down

0 comments on commit c73e353

Please sign in to comment.