I'm trying to run this code with the JAX backend. Static shape inference returns a result successfully, but when I execute the code dynamically, I get an error message that confuses me a lot.
Here is the code and the results I got.
Hi @Isa-Fay,
I looked into this and found that the mask argument has shape (4,), which is incompatible with the input shape (2, 3). The mask should typically have the same shape as the input tensor, or be broadcastable to that shape.
You can use mask=np.random.rand(*[2, 3]) instead.
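To illustrate the rule in this comment, here is a minimal sketch that assumes NumPy broadcasting semantics decide whether a mask is compatible. `mask_is_broadcastable` is a hypothetical helper, not part of Keras. A `(4,)` mask fails against a `(2, 3)` input because the trailing dimensions 4 and 3 differ and neither is 1, while the suggested `(2, 3)` mask passes.

```python
import numpy as np


def mask_is_broadcastable(mask_shape, input_shape):
    """Hypothetical helper: True if a mask of `mask_shape` broadcasts to `input_shape`."""
    mask_shape, input_shape = tuple(mask_shape), tuple(input_shape)
    try:
        # np.broadcast_shapes raises ValueError when the shapes are incompatible.
        result = np.broadcast_shapes(mask_shape, input_shape)
    except ValueError:
        return False
    # Reject masks that would enlarge the input itself, e.g. a (5, 2, 3) mask.
    return result == input_shape


inputs = np.random.rand(2, 3)
bad_mask = np.random.rand(4)            # the shape used in the original report
good_mask = np.random.rand(*[2, 3])     # the suggested replacement

print(mask_is_broadcastable(bad_mask.shape, inputs.shape))   # False
print(mask_is_broadcastable(good_mask.shape, inputs.shape))  # True
```

Shapes such as `(3,)` or `(2, 1)` would also pass this check, since they broadcast to `(2, 3)` without changing it.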
Thanks for your reply!
I agree there is a problem with the input, but I think the inconsistency between the static and dynamic outputs deserves attention, because it can confuse users and make the problem hard to track down. Hopefully Keras could add some input checks instead of passing invalid arguments directly to the backend.
I agree that the inconsistency between the static and dynamic outputs can be confusing for users. Adding input checks at the Keras layer level would definitely help prevent these kinds of errors by validating shapes and dimensions before passing them to the backend.
I have created Pull Request #20237 for a similar issue, #20221. Once I get a response on that, I will add the validation checks to all similar normalization layers.
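As a rough sketch of the kind of check discussed above, the hypothetical helper `check_mask_shape` validates a mask against the input shape, and a toy layer-like class calls it from both its static shape-inference path and its eager call path, so both reject a bad mask with the same message. `ToyMaskedNormalization`, its methods, and the error text are illustrative assumptions; they are not Keras's actual API and not the contents of PR #20237.

```python
import numpy as np


def check_mask_shape(mask_shape, input_shape):
    """Hypothetical check: raise if the mask cannot broadcast to the input shape."""
    mask_shape, input_shape = tuple(mask_shape), tuple(input_shape)
    try:
        broadcast = np.broadcast_shapes(mask_shape, input_shape)
    except ValueError:
        broadcast = None
    if broadcast != input_shape:
        raise ValueError(
            "`mask` must have the same shape as `inputs` or be broadcastable to it. "
            f"Received: inputs.shape={input_shape}, mask.shape={mask_shape}"
        )


class ToyMaskedNormalization:
    """Hypothetical stand-in for a normalization layer; NOT the Keras implementation."""

    def compute_output_shape(self, input_shape, mask_shape=None):
        # Static path: validate here too, so shape inference cannot succeed
        # on arguments that the eager path would reject.
        if mask_shape is not None:
            check_mask_shape(mask_shape, input_shape)
        return tuple(input_shape)

    def __call__(self, inputs, mask=None):
        inputs = np.asarray(inputs, dtype="float32")
        if mask is None:
            mask = np.ones_like(inputs)
        else:
            mask = np.asarray(mask, dtype="float32")
            # Eager path: same check, same error message.
            check_mask_shape(mask.shape, inputs.shape)
        mask = np.broadcast_to(mask, inputs.shape)
        # Masked mean and variance over the last axis (illustrative only).
        count = np.maximum(mask.sum(axis=-1, keepdims=True), 1.0)
        mean = (inputs * mask).sum(axis=-1, keepdims=True) / count
        var = (((inputs - mean) ** 2) * mask).sum(axis=-1, keepdims=True) / count
        return (inputs - mean) / np.sqrt(var + 1e-3)


layer = ToyMaskedNormalization()
print(layer(np.random.rand(2, 3), mask=np.random.rand(2, 3)).shape)  # (2, 3)
try:
    layer.compute_output_shape((2, 3), mask_shape=(4,))
except ValueError as err:
    print(err)  # the static path now fails just like the eager path
```

The design choice being illustrated is simply that validation lives in one place and runs before any backend computation, so the static and dynamic paths cannot disagree about whether the arguments are legal.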
Versions: Python 3.10, Keras 3.5.0, JAX 0.4.31
The static result:
The dynamic result: