I tested the code from this issue on Colab (CPU, GPU, and TPU v2) with JAX 0.4.26 and Flax 0.8.4, and on a MacBook CPU with JAX 0.4.30 and Flax 0.8.5. I could not reproduce the error you reported; the code runs fine in every one of these environments.
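Since the result seems to depend on the environment, it may help to print the installed versions before comparing. This is a minimal sketch, assuming the packages are installed under the distribution names `jax` and `flax`; `installed_version` is a hypothetical helper name, not part of either library:

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Collect the versions relevant to this issue (distribution names are assumed).
versions = {name: installed_version(name) for name in ("jax", "flax")}
for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")
```

Posting this output alongside the traceback makes it easier to tell whether a version mismatch is involved.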
import flax.linen as nn
import jax, jax.numpy as jnp
x = jax.random.normal(jax.random.key(0), (2, 3))
layer = nn.LSTMCell(features=4)
carry = layer.initialize_carry(jax.random.key(1), x.shape)
variables = layer.init(jax.random.key(2), carry, x)
new_carry, out = layer.apply(variables, carry, x)
Running this code, which comes from the documentation, raises the following error:
flax.errors.AssignSubModuleError: Submodule LSTMCell must be defined in `setup()` or in a method wrapped in `@compact` (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)