Fix JAX implementation of Argmax #809

Merged
merged 17 commits into from Jul 6, 2024
11 changes: 5 additions & 6 deletions pytensor/link/jax/dispatch/nlinalg.py
@@ -1,4 +1,5 @@
 import jax.numpy as jnp
+import numpy as np

 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
@@ -137,12 +138,10 @@ def argmax(x):

         # NumPy does not support multiple axes for argmax; this is a
         # work-around
-        keep_axes = jnp.array(
-            [i for i in range(x.ndim) if i not in axes], dtype="int64"
-        )
+        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
         # Not-reduced axes in front
         transposed_x = jnp.transpose(
-            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
+            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
         )
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
@@ -151,9 +150,9 @@
         # Otherwise reshape would complain citing float arg
         new_shape = (
             *kept_shape,
-            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
+            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
         )
-        reshaped_x = transposed_x.reshape(new_shape)
+        reshaped_x = transposed_x.reshape(tuple(new_shape))

         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

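For readers skimming the diff, the workaround itself is easy to miss. Below is a minimal standalone sketch (not part of the PR; the array, axes, and variable names are invented for illustration) of the multi-axis argmax trick the dispatch code implements: move the kept axes to the front, flatten the reduced axes into one, then take a single argmax over the last axis.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical inputs: argmax over axes (1, 2) of a 3-D array.
x = jnp.arange(24.0).reshape(2, 3, 4)
axes = (1, 2)

# Shape-affecting values stay as plain NumPy, so they are concrete
# (static) when JAX traces the function.
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
perm = tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))

# Non-reduced axes in front, then collapse the reduced axes into one.
transposed = jnp.transpose(x, perm)
kept_shape = transposed.shape[: len(keep_axes)]
flat_size = np.prod(transposed.shape[len(keep_axes) :], dtype="int64")
reshaped = transposed.reshape((*kept_shape, flat_size))

# A single argmax over the flattened trailing axis.
flat_idx = jnp.argmax(reshaped, axis=-1).astype("int64")
print(flat_idx.shape)  # (2,) -- one flat index per kept position
```

Each entry of `flat_idx` indexes into the flattened 3 × 4 block of reduced axes, matching NumPy's convention for argmax over multiple (raveled) axes.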
4 changes: 2 additions & 2 deletions tests/link/jax/test_extra_ops.py
@@ -65,9 +65,9 @@ def test_extra_ops():


 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
+    reason="JAX Numpy API does not support dynamic shapes",
 )
-def test_extra_ops_omni():
+def test_extra_ops_dynamic_shapes():
     a = matrix("a")
     a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
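The updated xfail reason names the actual constraint, which is also what motivates the `jnp` → `np` swap in the dispatch changes above: under `jax.jit`, values computed with `jnp` become tracers, and tracers cannot be used where JAX needs concrete shape metadata. A hypothetical minimal reproduction, independent of PyTensor:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def static_perm(x):
    # Plain NumPy/Python values are concrete at trace time, so they
    # can serve as static arguments like a transpose permutation.
    perm = tuple(int(i) for i in np.array([1, 0]))
    return jnp.transpose(x, perm)

@jax.jit
def traced_perm(x):
    # A jnp array becomes a tracer under jit; its values are unknown
    # while tracing, so using it as a permutation fails.
    perm = jnp.array([1, 0])
    return jnp.transpose(x, perm)

x = jnp.ones((2, 3))
print(static_perm(x).shape)  # (3, 2)
# traced_perm(x)  # raises a tracer-to-integer conversion error under jit
```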
11 changes: 1 addition & 10 deletions tests/link/jax/test_nlinalg.py
@@ -1,6 +1,5 @@
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse

 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
@@ -80,11 +79,7 @@ def assert_fn(x, y):
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)


-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_jax_basic_multiout_omni():
+def test_jax_max_and_argmax():
     # Test that a single output of a multi-output `Op` can be used as input to
     # another `Op`
     x = dvector()
@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni():
     compare_jax_and_py(out_fg, [np.r_[1, 2]])


-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
 def test_tensor_basics():
     y = vector("y")
     y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
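To try the renamed test's scenario outside the suite, a sketch along these lines should work. Hedged assumptions: `compare_jax_and_py` is the test suite's own helper, so this replaces it with a direct comparison; the input value mirrors the `np.r_[1, 2]` in the diff.

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.compile.function import function

x = pt.dvector("x")
# A multi-output computation: max and argmax of the same input.
outs = [pt.max(x), pt.argmax(x)]

jax_fn = function([x], outs, mode="JAX")  # JAX backend
py_fn = function([x], outs)               # default (reference) backend

data = np.r_[1.0, 2.0]
for jax_res, py_res in zip(jax_fn(data), py_fn(data)):
    np.testing.assert_allclose(jax_res, py_res)
```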