Duplicating the space gives an error: 'jax.core' has no attribute 'NamedShape'

#2
by dmathewwws - opened

Hi,

I duplicated this Space and got the following build error:

Traceback (most recent call last):
  File "/home/user/app/app.py", line 15, in <module>
    from whisper_jax import FlaxWhisperPipline
  File "/home/user/app/whisper_jax/__init__.py", line 18, in <module>
    from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
  File "/home/user/app/whisper_jax/modeling_flax_whisper.py", line 57, in <module>
    from whisper_jax import layers
  File "/home/user/app/whisper_jax/layers.py", line 63, in <module>
    def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'NamedShape'

I am running this on Google TPU v5e - 1x1 hardware. Also, I am curious: why is whisper_jax included as a directory in the repo rather than listed as a dependency in requirements.txt?
