diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py
index c79178a74c..0235f3c5db 100644
--- a/pytensor/link/jax/dispatch/nlinalg.py
+++ b/pytensor/link/jax/dispatch/nlinalg.py
@@ -1,4 +1,5 @@
 import jax.numpy as jnp
+import numpy as np
 
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
@@ -137,12 +138,10 @@ def argmax(x):
 
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
-        keep_axes = jnp.array(
-            [i for i in range(x.ndim) if i not in axes], dtype="int64"
-        )
+        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
         # Not-reduced axes in front
         transposed_x = jnp.transpose(
-            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
         )
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
@@ -151,9 +150,9 @@ def argmax(x):
         # Otherwise reshape would complain citing float arg
         new_shape = (
             *kept_shape,
-            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
-        reshaped_x = transposed_x.reshape(new_shape)
+        reshaped_x = transposed_x.reshape(tuple(new_shape))
 
         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
 
diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py
index c9920b31cc..94c442b165 100644
--- a/tests/link/jax/test_extra_ops.py
+++ b/tests/link/jax/test_extra_ops.py
@@ -65,9 +65,9 @@ def test_extra_ops():
 
 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
+    reason="JAX Numpy API does not support dynamic shapes",
 )
-def test_extra_ops_omni():
+def test_extra_ops_dynamic_shapes():
     a = matrix("a")
     a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
 
diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py
index 0bc6749448..4340b395cb 100644
--- a/tests/link/jax/test_nlinalg.py
+++ b/tests/link/jax/test_nlinalg.py
@@ -1,6 +1,5 @@
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse
 
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
@@ -80,11 +79,7 @@ def assert_fn(x, y):
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_jax_basic_multiout_omni():
+def test_jax_max_and_argmax():
     # Test that a single output of a multi-output `Op` can be used as input to
     # another `Op`
     x = dvector()
@@ -95,10 +90,6 @@
     compare_jax_and_py(out_fg, [np.r_[1, 2]])
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
 def test_tensor_basics():
     y = vector("y")
     y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)