From 71fe3b9e0db01e61fad2f28aefa5a8a1a17e5ca4 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 7 Jun 2024 15:06:41 +0800 Subject: [PATCH 01/14] Edited xfail reason for some 3 JAX tests --- tests/link/jax/test_extra_ops.py | 2 +- tests/link/jax/test_nlinalg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index c9920b31cc..dff882275b 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -65,7 +65,7 @@ def test_extra_ops(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_extra_ops_omni(): a = matrix("a") diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 3a64fda364..693810e5f5 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -82,7 +82,7 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to @@ -96,7 +96,7 @@ def test_jax_basic_multiout_omni(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="`dot` -> `Gemv` optimization is incompatible with JAX", ) def test_tensor_basics(): y = vector("y") From a170d33febd4dc71b3e29ea828299f9eba2db10b Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 21:38:39 +0800 Subject: [PATCH 02/14] Update tests/link/jax/test_extra_ops.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/link/jax/test_extra_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index dff882275b..94c442b165 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -67,7 +67,7 @@ def test_extra_ops(): version_parse(jax.__version__) >= version_parse("0.2.12"), reason="JAX Numpy API does not support dynamic shapes", ) -def test_extra_ops_omni(): +def test_extra_ops_dynamic_shapes(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) From d0fdd6ef83702c730776b735c1436075afd5c1cd Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 13:48:57 +0000 Subject: [PATCH 03/14] Remove expected fail mark for test that actually passes --- tests/link/jax/test_nlinalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 693810e5f5..44e4f38923 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -94,10 +94,6 @@ def test_jax_basic_multiout_omni(): compare_jax_and_py(out_fg, [np.r_[1, 2]]) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="`dot` -> `Gemv` optimization is incompatible with JAX", -) def test_tensor_basics(): y = vector("y") y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) From 68e245f5ddf2e0763fdd406833d856c268869525 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 28 Jun 2024 14:54:45 +0800 Subject: [PATCH 04/14] Changed reason for xfail --- tests/link/jax/test_nlinalg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 44e4f38923..a168a38d6a 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -13,7 +13,7 @@ from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor.math import MaxAndArgmax, maximum from pytensor.tensor.math import max as pt_max -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector +from pytensor.tensor.type import matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -82,12 +82,12 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="JAX Numpy API does not support dynamic shapes", + reason="Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError).", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to # another `Op` - x = dvector() + x = vector(dtype="float64", shape=(2,)) mx, amx = MaxAndArgmax([0])(x) out = mx * amx out_fg = FunctionGraph([x], [out]) From 36797913849037d74c1b4eb830530882ff3bdb18 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 7 Jun 2024 15:06:41 +0800 Subject: [PATCH 05/14] Edited xfail reason for some 3 JAX tests --- tests/link/jax/test_extra_ops.py | 2 +- tests/link/jax/test_nlinalg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index c9920b31cc..dff882275b 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -65,7 +65,7 @@ def test_extra_ops(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_extra_ops_omni(): a = matrix("a") diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 3a64fda364..693810e5f5 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -82,7 +82,7 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to @@ -96,7 +96,7 @@ def test_jax_basic_multiout_omni(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="`dot` -> `Gemv` optimization is incompatible with JAX", ) def test_tensor_basics(): y = vector("y") From ae0801d0e5c62b2f324fd8adb99d1d36b65bee2b Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 13:48:57 +0000 Subject: [PATCH 06/14] Remove expected fail mark for test that actually passes --- tests/link/jax/test_nlinalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 693810e5f5..44e4f38923 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -94,10 +94,6 @@ def test_jax_basic_multiout_omni(): compare_jax_and_py(out_fg, [np.r_[1, 2]]) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="`dot` -> `Gemv` optimization is incompatible with JAX", -) def test_tensor_basics(): y = vector("y") y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) From f287d6c483d21034b8de6731602a1f8a4221657c Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 21:38:39 +0800 Subject: [PATCH 07/14] Update tests/link/jax/test_extra_ops.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/link/jax/test_extra_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index dff882275b..94c442b165 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -67,7 +67,7 @@ def test_extra_ops(): version_parse(jax.__version__) >= version_parse("0.2.12"), reason="JAX Numpy API does not support dynamic shapes", ) -def test_extra_ops_omni(): +def test_extra_ops_dynamic_shapes(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) From 8b05952bef868253adf86273f2f1e14515685d4e Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 28 Jun 2024 14:54:45 +0800 Subject: [PATCH 08/14] Changed reason for xfail --- tests/link/jax/test_nlinalg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 44e4f38923..a168a38d6a 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -13,7 +13,7 @@ from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor.math import MaxAndArgmax, maximum from pytensor.tensor.math import max as pt_max -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector +from pytensor.tensor.type import matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -82,12 +82,12 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="JAX Numpy API does not support dynamic shapes", + reason="Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError).", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to # another `Op` - x = dvector() + x = vector(dtype="float64", shape=(2,)) mx, amx = MaxAndArgmax([0])(x) out = mx * amx out_fg = FunctionGraph([x], [out]) From 349beef562d754158bebc3fb44dc7f2aceb1aa0a Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 7 Jun 2024 15:06:41 +0800 Subject: [PATCH 09/14] Edited xfail reason for some 3 JAX tests --- tests/link/jax/test_extra_ops.py | 2 +- tests/link/jax/test_nlinalg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index c9920b31cc..dff882275b 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -65,7 +65,7 @@ def test_extra_ops(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_extra_ops_omni(): a = matrix("a") diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 0bc6749448..4867dbc4e3 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -82,7 +82,7 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="JAX Numpy API does not support dynamic shapes", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to @@ -97,7 +97,7 @@ def test_jax_basic_multiout_omni(): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", + reason="`dot` -> `Gemv` optimization is incompatible with JAX", ) def test_tensor_basics(): y = vector("y") From 8ed6ab2d3fa3330f60b631485127d4d623c62b0f Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 13:48:57 +0000 Subject: [PATCH 10/14] Remove expected fail mark for test that actually passes --- tests/link/jax/test_nlinalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 4867dbc4e3..5bd9dac8e4 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -95,10 +95,6 @@ def test_jax_basic_multiout_omni(): compare_jax_and_py(out_fg, [np.r_[1, 2]]) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="`dot` -> `Gemv` optimization is incompatible with JAX", -) def test_tensor_basics(): y = vector("y") y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) From b50a906816c61bf43a86ebde93c5ef13a5c9af21 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Mon, 17 Jun 2024 21:38:39 +0800 Subject: [PATCH 11/14] Update tests/link/jax/test_extra_ops.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/link/jax/test_extra_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index dff882275b..94c442b165 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -67,7 +67,7 @@ def test_extra_ops(): version_parse(jax.__version__) >= version_parse("0.2.12"), reason="JAX Numpy API does not support dynamic shapes", ) -def test_extra_ops_omni(): +def test_extra_ops_dynamic_shapes(): a = matrix("a") a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) From 89c7bff04b28ba0834f1b5afbbe8c27737ea687d Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 28 Jun 2024 14:54:45 +0800 Subject: [PATCH 12/14] Changed reason for xfail --- tests/link/jax/test_nlinalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 5bd9dac8e4..937d068c03 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -13,7 +13,7 @@ from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import max as pt_max -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector +from pytensor.tensor.type import matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py @@ -82,7 +82,7 @@ def assert_fn(x, y): @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="JAX Numpy API does not support dynamic shapes", + reason="Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError).", ) def test_jax_basic_multiout_omni(): # Test that a single output of a multi-output `Op` can be used as input to From 2c4a378fc475d82aaa9221abe1b3fd28b7a430e7 Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Wed, 3 Jul 2024 16:07:56 +0800 Subject: [PATCH 13/14] Rebase to remove MaxAndArgmax --- tests/link/jax/test_nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 937d068c03..d7484c5672 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -13,7 +13,7 @@ from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import max as pt_max -from pytensor.tensor.type import matrix, scalar, tensor3, vector +from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector from tests.link.jax.test_basic import compare_jax_and_py From be9c9f81c89227f59d154b1626a0784663c80f5f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Sat, 6 Jul 2024 16:14:49 +0800 Subject: [PATCH 14/14] Edited implementation of argmax. Changed test name for multiomni_output --- pytensor/link/jax/dispatch/nlinalg.py | 11 +++++------ tests/link/jax/test_nlinalg.py | 7 +------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index c79178a74c..0235f3c5db 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +import numpy as np from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blas import BatchedDot @@ -137,12 +138,10 @@ def argmax(x): # NumPy does not support multiple axes for argmax; this is a # work-around - keep_axes = jnp.array( - [i for i in range(x.ndim) if i not in axes], dtype="int64" - ) + keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") # Not-reduced axes in front transposed_x = jnp.transpose( - x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) + x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64")))) ) kept_shape = transposed_x.shape[: len(keep_axes)] reduced_shape = transposed_x.shape[len(keep_axes) :] @@ -151,9 +150,9 @@ def argmax(x): # Otherwise reshape would complain citing float arg new_shape = ( *kept_shape, - jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), + np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"), ) - reshaped_x = transposed_x.reshape(new_shape) + reshaped_x = transposed_x.reshape(tuple(new_shape)) max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index d7484c5672..4340b395cb 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -1,6 +1,5 @@ import numpy as np import pytest -from packaging.version import parse as version_parse from pytensor.compile.function import function from pytensor.compile.mode import Mode @@ -80,11 +79,7 @@ def assert_fn(x, y): compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Operation leads to JAX Tracer object is used in a context where a Python integer is expected (jax.errors.TracerIntegerConversionError).", -) -def test_jax_basic_multiout_omni(): +def test_jax_max_and_argmax(): # Test that a single output of a multi-output `Op` can be used as input to # another `Op` x = dvector()