mul got incompatible shapes for broadcasting #20219

Open
Isa-Fay opened this issue Sep 6, 2024 · 4 comments
Comments

@Isa-Fay

Isa-Fay commented Sep 6, 2024

I'm trying to run this code with the JAX backend. Static shape inference succeeds, but when I execute the layer dynamically I get an error message, which confused me a lot.
Here is the code and the results I got.

version:
python 3.10
keras 3.5.0
jax 0.4.31

import os
import re
import jax
import numpy as np
os.environ['KERAS_BACKEND']='jax'
import keras

layer = keras.layers.BatchNormalization(
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=False,
    scale=False,
    beta_initializer="zeros",
    gamma_initializer="ones",
    moving_mean_initializer="zeros",
    moving_variance_initializer="ones",
    synchronized=False,
    trainable=True,
    autocast=True,
)

result_static = layer.compute_output_shape([2, 3])

result_dynamic = layer(
    inputs=np.random.rand(*[2, 3]),
    training=True,
    mask=np.random.rand(*[4]),
)

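The static result:
an output shape, returned without error.

The dynamic result:
an error reporting that mul got incompatible shapes for broadcasting.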

@sanskarmodi8
Contributor

Hi @Isa-Fay,
I looked into this and found that the mask argument has shape (4,), which is incompatible with the input shape (2, 3). The mask should typically have the same shape as the input tensor, or be broadcastable to that shape.
You can use mask=np.random.rand(*[2, 3]) instead.
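
To make the shape mismatch concrete, here is a minimal pure-Python sketch of the NumPy-style broadcasting rule: shapes are aligned from the trailing dimension, and each pair of dimensions must either match or contain a 1. The helper name `is_broadcastable` is hypothetical and is not part of Keras or JAX. How `BatchNormalization` reshapes the mask internally before multiplying is not shown in this thread, so this only illustrates the general rule.

```python
from itertools import zip_longest


def is_broadcastable(shape_a, shape_b):
    """Return True if two shapes are compatible under NumPy-style broadcasting.

    Hypothetical helper for illustration only; not a Keras or JAX API.
    """
    # Align the shapes from the trailing dimension; a missing leading
    # dimension is treated as size 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            return False
    return True


print(is_broadcastable((2, 3), (4,)))    # False: trailing dims 3 and 4 differ
print(is_broadcastable((2, 3), (2, 3)))  # True: identical shapes
print(is_broadcastable((2, 3), (3,)))    # True: (3,) is treated as (1, 3)
```

Under this rule, the `(4,)` mask from the report cannot be combined with a `(2, 3)` input, which is consistent with the `mul` error in the issue title.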

@Isa-Fay
Author

Isa-Fay commented Sep 9, 2024

Thanks for your reply!
There is indeed a problem with the input, but I think the inconsistency between the static and dynamic outputs deserves attention, since it may cause confusion and difficulty for users. Hopefully Keras can add some input checks instead of passing invalid arguments directly to the backend.

@sanskarmodi8
Contributor

sanskarmodi8 commented Sep 9, 2024

Thank you for your feedback!

I agree that the inconsistency between the static and dynamic outputs can be confusing for users. Adding input checks at the Keras layer level would definitely help prevent these kinds of errors by validating shapes and dimensions before passing them to the backend.

I have created Pull Request #20237 for a similar issue, #20221. Once I get a response on that, I will add validation checks to all such normalization layers.
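
The kind of up-front check discussed here could look roughly like the sketch below. It is hypothetical and is not the code from the linked PR: the function name `check_mask_shape`, the rule it enforces (the mask must be broadcastable to the input shape), and the error wording are all assumptions made for illustration.

```python
def check_mask_shape(input_shape, mask_shape):
    """Raise a descriptive ValueError if the mask cannot broadcast to the input.

    Hypothetical validation helper; the rule and message are assumptions,
    not the check that Keras actually performs.
    """
    input_shape = tuple(input_shape)
    mask_shape = tuple(mask_shape)

    # A mask with more dimensions than the input can never broadcast to it.
    if len(mask_shape) > len(input_shape):
        raise ValueError(
            f"mask has more dimensions than inputs: "
            f"mask shape={mask_shape}, input shape={input_shape}"
        )

    # Compare from the trailing dimension: each mask dimension must be 1
    # or equal to the corresponding input dimension.
    for in_dim, mask_dim in zip(reversed(input_shape), reversed(mask_shape)):
        if mask_dim != 1 and mask_dim != in_dim:
            raise ValueError(
                f"mask is not broadcastable to inputs: "
                f"mask shape={mask_shape}, input shape={input_shape}"
            )


try:
    check_mask_shape((2, 3), (4,))
except ValueError as err:
    print(err)
```

Failing at the layer boundary with both shapes in the message would point the user at the `mask` argument directly, rather than at a backend-level `mul` call.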

@sachinprasadhs
Collaborator

This now works in keras-nightly with the fix from the PR linked above. The expected outcome is attached in the Gist here
