Skip to content

Commit d512271

Browse files
Fix JAX implementation of Argmax (#809)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 31bf682 commit d512271

File tree

3 files changed

+8
-18
lines changed

3 files changed

+8
-18
lines changed

pytensor/link/jax/dispatch/nlinalg.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import jax.numpy as jnp
2+
import numpy as np
23

34
from pytensor.link.jax.dispatch import jax_funcify
45
from pytensor.tensor.blas import BatchedDot
@@ -137,12 +138,10 @@ def argmax(x):
137138

138139
# NumPy does not support multiple axes for argmax; this is a
139140
# work-around
140-
keep_axes = jnp.array(
141-
[i for i in range(x.ndim) if i not in axes], dtype="int64"
142-
)
141+
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
143142
# Not-reduced axes in front
144143
transposed_x = jnp.transpose(
145-
x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
144+
x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
146145
)
147146
kept_shape = transposed_x.shape[: len(keep_axes)]
148147
reduced_shape = transposed_x.shape[len(keep_axes) :]
@@ -151,9 +150,9 @@ def argmax(x):
151150
# Otherwise reshape would complain citing float arg
152151
new_shape = (
153152
*kept_shape,
154-
jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
153+
np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
155154
)
156-
reshaped_x = transposed_x.reshape(new_shape)
155+
reshaped_x = transposed_x.reshape(tuple(new_shape))
157156

158157
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
159158

tests/link/jax/test_extra_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def test_extra_ops():
6565

6666
@pytest.mark.xfail(
6767
version_parse(jax.__version__) >= version_parse("0.2.12"),
68-
reason="Omnistaging cannot be disabled",
68+
reason="JAX Numpy API does not support dynamic shapes",
6969
)
70-
def test_extra_ops_omni():
70+
def test_extra_ops_dynamic_shapes():
7171
a = matrix("a")
7272
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
7373

tests/link/jax/test_nlinalg.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import pytest
3-
from packaging.version import parse as version_parse
43

54
from pytensor.compile.function import function
65
from pytensor.compile.mode import Mode
@@ -80,11 +79,7 @@ def assert_fn(x, y):
8079
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
8180

8281

83-
@pytest.mark.xfail(
84-
version_parse(jax.__version__) >= version_parse("0.2.12"),
85-
reason="Omnistaging cannot be disabled",
86-
)
87-
def test_jax_basic_multiout_omni():
82+
def test_jax_max_and_argmax():
8883
# Test that a single output of a multi-output `Op` can be used as input to
8984
# another `Op`
9085
x = dvector()
@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni():
9590
compare_jax_and_py(out_fg, [np.r_[1, 2]])
9691

9792

98-
@pytest.mark.xfail(
99-
version_parse(jax.__version__) >= version_parse("0.2.12"),
100-
reason="Omnistaging cannot be disabled",
101-
)
10293
def test_tensor_basics():
10394
y = vector("y")
10495
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

0 commit comments

Comments
 (0)