
AttributeError: module 'jax.config' has no attribute 'define_bool_state' when running big_vision on tpu-vm-base #3180

Open
yuyangshu opened this issue Jul 5, 2023 · 13 comments
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@yuyangshu

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): tpu-vm-base
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.6.7, jax: 0.4.13, jaxlib: 0.4.13
  • Python version: 3.8
  • GPU/TPU model and memory: tpu v3-8
  • CUDA version (if applicable): N/A

Problem you have encountered:

Hello, I am seeing

AttributeError: module 'jax.config' has no attribute 'define_bool_state'

when running the big_vision library on tpu-vm-base.

I saw a similar issue in another library that was apparently resolved by pinning the jax version to 0.4.9, but when I attempted that it did not work.

I also tried pinning the versions of all packages in the requirements.txt of big_vision, i.e.

absl-py==1.4.0
clu==0.0.8
einops==0.6.0
flax==0.6.7
git+https://github.com/google/flaxformer
git+https://github.com/deepmind/optax.git
git+https://github.com/akolesnikoff/panopticapi.git@mute
overrides==7.3.1
tensorflow==2.12.0
tfds-nightly==4.8.3.dev202303250044
tensorflow-addons==0.19.0
tensorflow-text==2.12.0
tensorflow-gan==2.1.0

at the same time as pinning jax to 0.4.9, but that did not work either.

I had to use a full requirements.txt, obtained by running pip freeze in a local venv created on 2023-03-26, to get the library running on TPU again.

What you expected to happen:

I was able to run big_vision on tpu-vm-base on a v3-8 TPU node without pinning any package versions as late as 2023-05-24.

Logs, error messages, etc:

Installing collected packages: libtpu-nightly, zipp, numpy, scipy, opt-einsum, ml-dtypes, importlib-metadata, jaxlib, jax
Successfully installed importlib-metadata-6.7.0 jax-0.4.13 jaxlib-0.4.13 libtpu-nightly-0.1.dev20230622 ml-dtypes-0.2.0 numpy-1.24.4 opt-einsum-3.3.0 scipy-1.10.1 zipp-3.15.0
Collecting git+https://github.com/google/flaxformer (from -r big_vision/requirements.txt (line 5))
  Cloning https://github.com/google/flaxformer to /tmp/pip-req-build-925ai1ze
  Running command git clone --filter=blob:none --quiet https://github.com/google/flaxformer /tmp/pip-req-build-925ai1ze
  Resolved https://github.com/google/flaxformer to commit 9adaa4467cf17703949b9f537c3566b99de1b416
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'

[omitted]

Collecting flax==0.6.7 (from -r big_vision/requirements.txt (line 4))
  Downloading flax-0.6.7-py3-none-any.whl (214 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 214.2/214.2 kB 28.6 MB/s eta 0:00:00

[omitted]

Building wheels for collected packages: flaxformer, optax, panopticapi, ml-collections, promise
  Building wheel for flaxformer (pyproject.toml): started
  Building wheel for flaxformer (pyproject.toml): finished with status 'done'
  Created wheel for flaxformer: filename=flaxformer-0.8.1-py3-none-any.whl size=321948 sha256=df38d4209289e8a71a245b56f95490ec0ce9c2bbfaa164fd00d1b7e2f80b5869

[omitted]

Successfully installed MarkupSafe-2.1.3 Pillow-10.0.0 PyYAML-6.0 absl-py-1.4.0 aqtp-0.1.1 array-record-0.4.0 astunparse-1.6.3 cached_property-1.5.2 cachetools-5.3.1 certifi-2023.5.7 charset-normalizer-3.1.0 chex-0.1.7 click-8.1.3 cloudpickle-2.2.1 clu-0.0.8 contextlib2-21.6.0 dacite-1.8.1 decorator-5.1.1 dm-tree-0.1.8 einops-0.6.0 etils-1.3.0 flatbuffers-23.5.26 flax-0.6.7 flaxformer-0.8.1 gast-0.4.0 google-auth-2.21.0 google-auth-oauthlib-1.0.0 google-pasta-0.2.0 googleapis-common-protos-1.59.1 grpcio-1.56.0 h5py-3.9.0 idna-3.4 importlib-resources-5.12.0 keras-2.12.0 libclang-16.0.0 markdown-3.4.3 markdown-it-py-3.0.0 mdurl-0.1.2 ml-collections-0.1.1 msgpack-1.0.5 nest_asyncio-1.5.6 numpy-1.23.5 oauthlib-3.2.2 optax-0.1.5 orbax-0.1.7 overrides-7.3.1 packaging-23.1 panopticapi-0.1 promise-2.3 protobuf-4.23.3 psutil-5.9.5 pyasn1-0.5.0 pyasn1-modules-0.3.0 pygments-2.15.1 requests-2.31.0 requests-oauthlib-1.3.1 rich-13.4.2 rsa-4.9 six-1.16.0 tensorboard-2.12.3 tensorboard-data-server-0.7.1 tensorflow-2.12.0 tensorflow-addons-0.19.0 tensorflow-datasets-4.9.2 tensorflow-estimator-2.12.0 tensorflow-gan-2.1.0 tensorflow-hub-0.13.0 tensorflow-io-gcs-filesystem-0.32.0 tensorflow-metadata-1.13.1 tensorflow-probability-0.20.1 tensorflow-text-2.12.0 tensorstore-0.1.40 termcolor-2.3.0 tfds-nightly-4.8.3.dev202303250044 toml-0.10.2 toolz-0.12.0 tqdm-4.65.0 typeguard-4.0.0 typing-extensions-4.7.1 urllib3-1.26.16 werkzeug-2.3.6 wheel-0.40.0 wrapt-1.14.1
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/yuyang/big_vision/train.py", line 28, in <module>
    import big_vision.evaluators.common as eval_common
  File "/home/yuyang/big_vision/evaluators/common.py", line 22, in <module>
    import flax
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
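The traceback shows flax calling jax.config.define_bool_state at import time. One workaround some users apply is to restore the old name before importing flax; the sketch below assumes the helper was renamed to bool_state in jax 0.4.25 (consistent with the reports later in this thread), and patch_define_bool_state is a hypothetical helper, not part of flax or jax:

```python
def patch_define_bool_state(config_module):
    """Best-effort shim: restore the old define_bool_state name on a jax
    config module if it was renamed to bool_state. Must run before
    `import flax` so flax's configurations module finds the attribute."""
    if not hasattr(config_module, "define_bool_state") and hasattr(config_module, "bool_state"):
        config_module.define_bool_state = config_module.bool_state
    return config_module

# On a real system this would run before importing flax:
#   from jax import config as jax_config
#   patch_define_bool_state(jax_config)
#   import flax
```

Pinning compatible versions (as discussed below) is the cleaner fix; this shim only papers over the rename for old flax releases.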

Steps to reproduce:

  1. Check out big_vision locally
git clone git@github.com:google-research/big_vision.git
  2. Create a TPU node
gcloud compute tpus tpu-vm create $VM_NAME --zone=$ZONE --accelerator-type=v3-8 --version=tpu-vm-base
  3. Upload big_vision to the TPU and start training
gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $VM_NAME: --zone=$ZONE --worker=all
gcloud compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir gs://$BUCKET_NAME/workdirs/`date '+%m-%d_%H%M'`"
@cgarciae (Collaborator) commented Jul 5, 2023

Hey @yuyangshu, currently jax does expose define_bool_state:

https://github.com/google/jax/blob/main/jax/_src/config.py#L174

Can you try using flax==0.6.11?

@hwh7 commented Jul 7, 2023

I have the same issue when I try to run with CUDA.
Trying flax==0.6.11 doesn't work for me.

@chiamp (Collaborator) commented Sep 28, 2023

Could you try upgrading flax to the latest version?

@jiyuuchc

Can confirm this is an issue with jax==0.4.25.

Downgrading to jax==0.4.24 solves the problem.

@Alxmrphi commented Mar 7, 2024

I have the same issue now (flax == 0.6.11 and jax==0.4.25).
Downgrading to 0.4.24 now gives me a different error, to which the solution is to downgrade to 0.4.23.
Let's see at what version this ends ...

@chiamp (Collaborator) commented Mar 7, 2024

@jiyuuchc what flax version are you using?

@jiyuuchc commented Mar 7, 2024 via email

@chiamp (Collaborator) commented Mar 7, 2024

@Alxmrphi can you try upgrading to flax==0.7.5 or higher?

@sudo-Boris

I need to specify that I want to use CUDA. The only command that installs a jax version that finds my GPUs is the standard
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html.

But this command installs jax 0.4.26 with the discussed error. When I manually downgrade to let's say 0.4.22 (pip install jax==0.4.22), I again get a version that doesn't find my GPU...

Any idea on what to do?

@Alxmrphi commented May 1, 2024

@sudo-Boris Could you try specifying the version like so:

pip install "jax[cuda12_pip]"==0.4.22 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This has previously worked for me.

@sudo-Boris

@Alxmrphi Thank you for the quick reply!

I just can't seem to find a CUDA-enabled jax version that works with my CUDA version (12.1).

For 0.4.23, I get WARNING: jax 0.4.23 does not provide the extra 'cuda12-pip' at installation, and when executing code I get jaxlib.xla_extension.XlaRuntimeError: INTERNAL: XLA requires ptxas version 11.8 or higher.

For 0.4.24, I get the same warning at installation, and the following annoying warning when executing the same code:
W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

I guess I need to tweak around with the different library versions to find a combination that works.

Thank you :)

@Alxmrphi commented May 1, 2024

Oh, I've been there; your situation is exactly what I was dealing with not so long ago (as you can see from my earlier comments in this thread). I can share what I'm currently using, which works with CUDA 12.2, so it should hopefully work with CUDA 12.1 too, since I'm on an earlier version of JAX than you (patch version 20 instead of 22). This is from an HPC job, so I'll just copy my environment setup script header. Take the version numbers from the settings below and try those out.

module purge
module load gcc/11.3.0
module load python/3.10
module load cuda/12.2
module load cudnn
module load scipy-stack

virtualenv --no-download $ENVDIR
source $ENVDIR/bin/activate

pip install --no-index torch torchvision
pip install --no-index flax==0.7.5+computecanada
pip install --no-index wandb
pip install jax==0.4.20+computecanada --no-index
pip install orbax_checkpoint==0.5.2+computecanada --no-index

If it doesn't work for you, then I guess you just have to tweak the versions. I think you'll still need to use the jax[cuda12_pip]==version style of installation in a regular terminal setup.

Best of luck :)

@chiamp chiamp added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label May 1, 2024
copybara-service bot pushed a commit to google-research/self-organising-systems that referenced this issue May 30, 2024
This is because we are using evojax which needs an old version of flax. Old flax is not compatible with new jax ([see this discussion](google/flax#3180)).

This is a temporary patch. Hopefully. We will figure out how to either fix evojax or stop using it for future experiments.

Also, fix a bug on the sexual reproduction colab where I was importing a function that I deleted, causing the colab to fail.

PiperOrigin-RevId: 638631920
@eleninisioti

I had the same issue, and in the end these versions worked for me:

jaxlib: 0.4.26+cuda12.cudnn89 (installed from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html )
jax: 0.4.26
flax: 0.8.4
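Pulling the version reports in this thread together, a rough compatibility check can be sketched. The cutoffs below are assumptions distilled from the comments above (define_bool_state gone as of jax 0.4.25; flax 0.7.5+ no longer needing it), not an official support matrix, and likely_compatible is a hypothetical helper:

```python
def _parse(version):
    # Keep only the numeric major.minor.patch part, ignoring local suffixes
    # such as "+computecanada" or "+cuda12.cudnn89".
    core = version.split("+")[0]
    return tuple(int(p) for p in core.split(".")[:3])

def likely_compatible(flax_version, jax_version):
    """Heuristic from this thread: old flax needs jax.config.define_bool_state,
    which jax removed in 0.4.25; flax >= 0.7.5 no longer calls it."""
    jax_has_it = _parse(jax_version) < (0, 4, 25)
    flax_needs_it = _parse(flax_version) < (0, 7, 5)
    return jax_has_it or not flax_needs_it
```

For example, this flags flax 0.6.11 with jax 0.4.25 (Alxmrphi's broken combo) while accepting flax 0.8.4 with jax 0.4.26 (eleninisioti's working combo).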
