Skip to content

Commit da10cf3

Browse files
authored
Fix JAX CPU tests - saved_model_export.py (keras-team#20962)
With JAX 0.5.1, `jax2tf` exports XLA that is not compatible with TensorFlow 2.18, making the `saved_model_export.py` tests fail. Since Tensorflow 2.19 is not out yet, we pin JAX to 0.5.0 for now.
1 parent 6c3dd68 commit da10cf3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ torch==2.6.0+cpu
1010
torch-xla==2.6.0;sys_platform != 'darwin'
1111

1212
# Jax.
13-
jax[cpu]
13+
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
14+
# Note that we test against the latest JAX on GPU.
15+
jax[cpu]==0.5.0
1416
flax
1517

1618
# Common deps.

0 commit comments

Comments
 (0)