AttributeError: module 'jaxlib.xla_extension' has no attribute 'DeviceArrayBase'


I'm trying to run a Colab notebook for image generation with JAX and ran into the following error:


    WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

    ---------------------------------------------------------------------------

    AttributeError                            Traceback (most recent call last)

    <ipython-input-7-73b0723cc3af> in <cell line: 23>()
     21 import jax.numpy as jnp
     22 import jax.scipy as jsp
    ---> 23 import jaxtorch
     24 from jaxtorch import PRNG, Context, Module, nn, init
     25 from tqdm import tqdm

    3 frames

    /content/./jax-guided-diffusion/jaxtorch/monkeypatches.py in register(**kwargs)
     16             print(f'Not monkeypatching DeviceArray and Tracer with `{attr}`, because that method is already implemented.', file=sys.stderr)
     17             continue
    ---> 18         setattr(jaxlib.xla_extension.DeviceArrayBase, attr, fun)
     19         setattr(jax.interpreters.xla.DeviceArray, attr, fun)
     20         setattr(jax.core.Tracer, attr, fun)

    AttributeError: module 'jaxlib.xla_extension' has no attribute 'DeviceArrayBase'

I tried to solve this by switching between different JAX versions and every GPU that Colab offers, but couldn't find a solution. I'd really appreciate any help with this!

Link to the notebook: click

Answer by jakevdp:

DeviceArray and related types were deprecated and removed in JAX v0.4.1 (see the Changelog). It looks like the version of jaxtorch you are using is not compatible with more recent JAX versions. If there is no newer version of jaxtorch available, I would suggest trying to use it with JAX version 0.3.25 or older.
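If you want to check an environment before importing jaxtorch, a quick probe along these lines may help. It is a minimal sketch that uses only the standard library plus whatever JAX install happens to be present. The helper names (`parse_version`, `within_suggested_range`, `has_device_array_base`) are hypothetical and are not part of JAX or jaxtorch. The sketch tests for the exact attribute that `monkeypatches.py` tries to set, and separately compares a version string against the 0.3.25 ceiling suggested in the answer.

```python
import importlib
import re

# Hypothetical helpers for illustration; not part of JAX or jaxtorch.
# The answer suggests JAX 0.3.25 or older for jaxtorch releases that use DeviceArray.
SUGGESTED_MAX_VERSION = (0, 3, 25)


def parse_version(version: str) -> tuple:
    """Turn '0.3.25' or '0.4.1rc1' into a comparable (major, minor, patch) tuple."""
    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep leading digits, drop suffixes like 'rc1'
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (3 - len(numbers))  # pad short versions such as '0.4'
    return tuple(numbers)


def within_suggested_range(version: str) -> bool:
    """True if `version` is at or below the 0.3.25 ceiling suggested in the answer."""
    return parse_version(version) <= SUGGESTED_MAX_VERSION


def has_device_array_base() -> bool:
    """True if the installed jaxlib still exposes xla_extension.DeviceArrayBase,
    the attribute that jaxtorch's monkeypatches.py tries to set."""
    try:
        xla_extension = importlib.import_module("jaxlib.xla_extension")
    except ImportError:  # jaxlib is missing, or this submodule no longer exists
        return False
    return hasattr(xla_extension, "DeviceArrayBase")


if __name__ == "__main__":
    try:
        import jax
        print("jax", jax.__version__,
              "| within suggested range:", within_suggested_range(jax.__version__))
    except ImportError:
        print("jax is not installed")
    print("DeviceArrayBase present:", has_device_array_base())
```

Probing for `DeviceArrayBase` directly is probably the more reliable of the two checks, because the attribute lives in jaxlib, and jax and jaxlib are separate packages whose versions need not match exactly. Comparing parsed tuples (rather than raw strings) also avoids treating "0.3.4" as newer than "0.3.25".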