AttributeError: module 'jax.config' has no attribute 'define_bool_state' when running `big_vision` on `tpu-vm-base` #3180
Comments
Hey @yuyangshu, currently https://github.com/google/jax/blob/main/jax/_src/config.py#L174 Can you try using |
I have the same issue when I try to run with CUDA. |
Could you try upgrading flax to the latest version? |
Can confirm this is an issue with jax==0.4.25. Downgrading to jax==0.4.24 solves this problem. |
I have the same issue now (flax == 0.6.11 and jax==0.4.25). |
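The version reports above can be turned into a quick environment check. This is a minimal sketch, not an official compatibility table: the jax threshold comes from the 0.4.24 / 0.4.25 reports in these comments, the flax threshold is only the version a maintainer suggests upgrading to further down, and the helper names (`parse_version`, `likely_affected`, `installed_version`) are made up for illustration. It does not cover the original report, which hit the error with jax 0.4.13.

```python
import re
from importlib import metadata

# From this thread: jax 0.4.24 still exposes define_bool_state (with a
# DeprecationWarning) and jax 0.4.25 raises the AttributeError.
FIRST_JAX_WITHOUT_DEFINE_BOOL_STATE = (0, 4, 25)
# Assumption: 0.7.5 is the flax version suggested in this thread; it is not
# verified here as the exact first release that avoids the error.
SUGGESTED_MIN_FLAX = (0, 7, 5)


def parse_version(text):
    """Return the leading 'X.Y.Z' of a version string as a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.groups())


def likely_affected(jax_version, flax_version):
    """True when jax is new enough to lack the attribute and flax is old."""
    return (
        parse_version(jax_version) >= FIRST_JAX_WITHOUT_DEFINE_BOOL_STATE
        and parse_version(flax_version) < SUGGESTED_MIN_FLAX
    )


def installed_version(package):
    """Return the installed version of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    versions = {name: installed_version(name) for name in ("flax", "jax")}
    print(versions)
    if all(versions.values()):
        print("likely affected:",
              likely_affected(versions["jax"], versions["flax"]))
```

Running the script in the failing environment prints the installed flax and jax versions, which is also the information the maintainers keep asking for in this thread.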
@jiyuuchc what flax version are you using? |
flax == 0.7.5
In fact, on jax==0.4.24, referencing `jax.config.define_bool_state` already raises a DeprecationWarning:
```
DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead. <function define_bool_state at 0x7f0fadff95a0>
```
So it seems this removal was planned all along. I don't know why it came as a surprise in the first place.
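One way to avoid being surprised by a removal like this is to promote `DeprecationWarning` to an error in tests, so the deprecated reference fails one release early. The sketch below uses only the standard `warnings` module. `FakeConfig` is a hypothetical stand-in that mimics an attribute warning on access; it is not how jax implements its config, and `check_attribute` is an illustrative helper name.

```python
import warnings


class FakeConfig:
    """Hypothetical stand-in for an object with one deprecated attribute."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name == "define_bool_state":
            warnings.warn(
                "define_bool_state is deprecated.",
                DeprecationWarning,
                stacklevel=2,
            )
            return lambda *args, **kwargs: None
        raise AttributeError(name)


def check_attribute(obj, name):
    """Classify an attribute as 'ok', 'deprecated', or 'removed'."""
    with warnings.catch_warnings():
        # Inside this block every DeprecationWarning is raised as an exception.
        warnings.simplefilter("error", DeprecationWarning)
        try:
            getattr(obj, name)
        except DeprecationWarning:
            return "deprecated"
        except AttributeError:
            return "removed"
    return "ok"


if __name__ == "__main__":
    print(check_attribute(FakeConfig(), "define_bool_state"))
    print(check_attribute(FakeConfig(), "some_missing_name"))
```

The same effect is available without code changes by running the interpreter with `python -W error::DeprecationWarning`, which applies the filter to the whole process.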
|
@Alxmrphi can you try upgrading to flax==0.7.5 or higher? |
I need to specify that I want to use CUDA. The only command that installs a JAX version that finds my GPUs is the standard But this command installs jax 0.4.26, which has the error discussed above. When I manually downgrade to, let's say, 0.4.22 ( Any idea what to do? |
@sudo-Boris Could you try specifying the version like so:
This has previously worked for me. |
@Alxmrphi Thank you for the quick reply! I just can't seem to find a CUDA-enabled version of jax that works with my CUDA version (12.1). For 0.4.23, I get For 0.4.24, I get the same warning at installation and the following annoying warning when executing the same code I guess I need to try different library versions until I find a combination that works. Thank you :) |
Oh, I've been there. Your situation is what I was dealing with not so long ago (as you can see from my comments earlier in this thread). I can tell you what I'm currently using. It works with CUDA 12.2, so I think it will hopefully work with CUDA 12.1, because I'm using an earlier version of JAX than you (patch version 20 instead of 22). This is from an HPC job, so I'll just copy my environment setup script header. Just take the version numbers from the settings below and try those out.
If it doesn't work for you then I guess you just have to tweak the versions. I think you'll still need to use the jax[cuda12_pip]==version style of installation in a regular terminal setup. Best of luck :) |
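The `jax[cuda12_pip]==version` style mentioned above can be assembled programmatically when trying several versions in a row. This is a minimal sketch under stated assumptions: the extra name `cuda12_pip` is taken from this comment, `0.4.20` is the patch version mentioned above, the find-links URL is the one quoted later in this thread (not checked here), and both helper names are hypothetical. It only builds the command; it does not run pip.

```python
# Find-links page quoted later in this thread; whether it still serves the
# needed wheels is an assumption.
JAX_CUDA_RELEASES = (
    "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
)


def pinned_requirement(package, version, extra=None):
    """Build a pinned requirement such as 'jax[cuda12_pip]==0.4.20'."""
    name = f"{package}[{extra}]" if extra else package
    return f"{name}=={version}"


def pip_install_args(requirements, find_links=None):
    """Return the argument list for a 'pip install' of pinned requirements."""
    args = ["pip", "install", *requirements]
    if find_links:
        # -f is pip's short form of --find-links.
        args += ["-f", find_links]
    return args


if __name__ == "__main__":
    requirement = pinned_requirement("jax", "0.4.20", extra="cuda12_pip")
    print(" ".join(pip_install_args([requirement], JAX_CUDA_RELEASES)))
```

Keeping the version as a parameter makes it easy to step through candidate releases until one matches the local CUDA and cuDNN installation.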
This is because we are using evojax which needs an old version of flax. Old flax is not compatible with new jax ([see this discussion](google/flax#3180)). This is a temporary patch. Hopefully. We will figure out how to either fix evojax or stop using it for future experiments. Also, fix a bug on the sexual reproduction colab where I was importing a function that I deleted, causing the colab to fail. PiperOrigin-RevId: 638631920
I had the same issue and in the end these versions worked for me: jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html) |
System information
`pip show flax jax jaxlib`: flax: 0.6.7, jax: 0.4.13, jaxlib: 0.4.13

Problem you have encountered:
Hello, I am seeing `AttributeError: module 'jax.config' has no attribute 'define_bool_state'` when running the `big_vision` library on `tpu-vm-base`.
I saw a similar issue in another library, where it seems to have been resolved by pinning `jax` to `0.4.9`, but when I attempted that it did not work.
I also tried pinning the versions of all packages in the `requirements.txt` of `big_vision` at the same time as pinning `jax` to `0.4.9`, but that did not work either.
I had to use a full `requirements.txt` obtained from running `pip freeze` in a local venv created on 2023-03-26 to get the library running on TPU again.

What you expected to happen:
I was able to run `big_vision` on `tpu-vm-base` on a `v3-8` TPU node without pinning any package versions as late as 2023-05-24.

Logs, error messages, etc:

Steps to reproduce:
`big_vision` locally
`big_vision` to the TPU and start training
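
The `pip freeze` workaround described in the issue body suggests a way to narrow down the culprit: compare the known-good freeze output with the failing environment and look at what changed. Below is a minimal sketch that handles only plain `name==version` lines. The function names are illustrative, and the "known good" pins in the example are made up, since the thread does not list them; only the "current" pins come from the `pip show` output above.

```python
def parse_freeze(text):
    """Parse 'pip freeze' style text into a {name: version} dict.

    Only plain 'name==version' lines are handled; comments, blank lines,
    and other requirement forms (URLs, editable installs) are skipped.
    """
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.strip().lower()] = version.strip()
    return pins


def changed_pins(known_good, current):
    """Return {name: (old, new)} for packages whose pinned version differs.

    A package missing on one side is reported with None for that side.
    """
    changes = {}
    for name in sorted(set(known_good) | set(current)):
        old, new = known_good.get(name), current.get(name)
        if old != new:
            changes[name] = (old, new)
    return changes


if __name__ == "__main__":
    # Hypothetical known-good pins, for illustration only.
    known_good = parse_freeze("flax==0.6.7\njax==0.4.9\njaxlib==0.4.9\n")
    # Versions reported under "System information" in this issue.
    current = parse_freeze("flax==0.6.7\njax==0.4.13\njaxlib==0.4.13\n")
    for name, (old, new) in changed_pins(known_good, current).items():
        print(f"{name}: {old} -> {new}")
```

Packages that show up in the diff are the candidates to pin one at a time until the failure disappears.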