
Commit f35ce26

HangenYuu authored and ricardoV94 committed
Reorganized JAX link folder structure
1 parent 308bc01 commit f35ce26

7 files changed: +173 -146 lines changed


pytensor/link/jax/dispatch/__init__.py (+10 -8)

@@ -2,18 +2,20 @@
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 
 # Load dispatch specializations
-import pytensor.link.jax.dispatch.scalar
-import pytensor.link.jax.dispatch.tensor_basic
-import pytensor.link.jax.dispatch.subtensor
-import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.blas
+import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.math
 import pytensor.link.jax.dispatch.nlinalg
-import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.random
-import pytensor.link.jax.dispatch.elemwise
+import pytensor.link.jax.dispatch.scalar
 import pytensor.link.jax.dispatch.scan
-import pytensor.link.jax.dispatch.sparse
-import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.sort
+import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.tensor_basic
 
 # isort: on
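
The imports above populate the single jax_funcify dispatch registry, so compiling a graph in JAX mode picks up the relocated conversions (e.g. BatchedDot now living in dispatch.blas) with no user-facing change. A minimal sketch of that flow, not part of the commit, assuming pytensor and jax are installed (variable names are illustrative):

# Sketch: compile a BatchedDot graph with the JAX backend.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import blas as pt_blas

a = pt.tensor3("a")
b = pt.tensor3("b")
out = pt_blas.BatchedDot()(a, b)

# mode="JAX" routes compilation through the JAX linker, which looks up each
# Op in the jax_funcify registry populated by the imports shown above.
fn = pytensor.function([a, b], out, mode="JAX")

a_val = np.ones((10, 5, 3), dtype=pytensor.config.floatX)
b_val = np.ones((10, 3, 2), dtype=pytensor.config.floatX)
print(fn(a_val, b_val).shape)  # (10, 5, 2)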

pytensor/link/jax/dispatch/blas.py (+14)

@@ -0,0 +1,14 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blas import BatchedDot
+
+
+@jax_funcify.register(BatchedDot)
+def jax_funcify_BatchedDot(op, **kwargs):
+    def batched_dot(a, b):
+        if a.shape[0] != b.shape[0]:
+            raise TypeError("Shapes must match along the first dimension of BatchedDot")
+        return jnp.matmul(a, b)
+
+    return batched_dot
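
As an illustrative check (not part of the commit), the dispatch can also be called directly: jax_funcify(BatchedDot()) returns the batched_dot closure, which behaves like an ordinary JAX function and raises the TypeError eagerly when the batch dimensions disagree:

# Sketch: obtain and call the JAX implementation registered above.
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot

batched_dot = jax_funcify(BatchedDot())
a = jnp.ones((10, 5, 3))
b = jnp.ones((10, 3, 2))
print(batched_dot(a, b).shape)  # (10, 5, 2)

# Mismatched batch dimensions hit the explicit check in the closure.
try:
    batched_dot(a[:-1], b)
except TypeError as err:
    print(err)  # Shapes must match along the first dimension of BatchedDot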

pytensor/link/jax/dispatch/math.py (+60)

@@ -0,0 +1,60 @@
+import jax.numpy as jnp
+import numpy as np
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.math import Argmax, Dot, Max
+
+
+@jax_funcify.register(Dot)
+def jax_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return jnp.dot(x, y)
+
+    return dot
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
+    axis = op.axis
+
+    def max(x):
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
+        # Not-reduced axes in front
+        transposed_x = jnp.transpose(
+            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
+        )
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+        # Otherwise reshape would complain citing float arg
+        new_shape = (
+            *kept_shape,
+            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
+        )
+        reshaped_x = transposed_x.reshape(tuple(new_shape))
+
+        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+        return max_idx_res
+
+    return argmax
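
The multi-axis Argmax conversion above works by moving the kept axes to the front, collapsing the reduced axes into a single trailing axis, and taking argmax over that flattened axis. A standalone illustration of the same trick on a concrete array (not part of the commit; values chosen purely for demonstration):

# Sketch: multi-axis argmax via transpose + reshape, mirroring the dispatch above.
import jax.numpy as jnp
import numpy as np

x = np.arange(24).reshape(2, 3, 4)   # x[i, j, k] = 12*i + 4*j + k
axes = (0, 2)                        # axes to reduce over
keep_axes = (1,)                     # axis to keep

transposed = jnp.transpose(x, keep_axes + axes)  # shape (3, 2, 4): kept axis in front
flattened = transposed.reshape(3, 2 * 4)         # reduced axes collapsed into one
flat_idx = jnp.argmax(flattened, axis=-1)        # argmax over the flattened axis

print(flat_idx)                             # [7 7 7]: position 7 of 8 for every kept index
print(jnp.unravel_index(flat_idx, (2, 4)))  # per-axis indices along the reduced axes: (1, 3)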

pytensor/link/jax/dispatch/nlinalg.py (-68)

@@ -1,9 +1,6 @@
 import jax.numpy as jnp
-import numpy as np
 
 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -80,14 +77,6 @@ def qr_full(x, mode=mode):
     return qr_full
 
 
-@jax_funcify.register(Dot)
-def jax_funcify_Dot(op, **kwargs):
-    def dot(x, y):
-        return jnp.dot(x, y)
-
-    return dot
-
-
 @jax_funcify.register(MatrixPinv)
 def jax_funcify_Pinv(op, **kwargs):
     def pinv(x):
@@ -96,66 +85,9 @@ def pinv(x):
     return pinv
 
 
-@jax_funcify.register(BatchedDot)
-def jax_funcify_BatchedDot(op, **kwargs):
-    def batched_dot(a, b):
-        if a.shape[0] != b.shape[0]:
-            raise TypeError("Shapes must match in the 0-th dimension")
-        return jnp.matmul(a, b)
-
-    return batched_dot
-
-
 @jax_funcify.register(KroneckerProduct)
 def jax_funcify_KroneckerProduct(op, **kwargs):
     def _kron(x, y):
         return jnp.kron(x, y)
 
     return _kron
-
-
-@jax_funcify.register(Max)
-def jax_funcify_Max(op, **kwargs):
-    axis = op.axis
-
-    def max(x):
-        max_res = jnp.max(x, axis)
-
-        return max_res
-
-    return max
-
-
-@jax_funcify.register(Argmax)
-def jax_funcify_Argmax(op, **kwargs):
-    axis = op.axis
-
-    def argmax(x):
-        if axis is None:
-            axes = tuple(range(x.ndim))
-        else:
-            axes = tuple(int(ax) for ax in axis)
-
-        # NumPy does not support multiple axes for argmax; this is a
-        # work-around
-        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
-        # Not-reduced axes in front
-        transposed_x = jnp.transpose(
-            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
-        )
-        kept_shape = transposed_x.shape[: len(keep_axes)]
-        reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-        # Otherwise reshape would complain citing float arg
-        new_shape = (
-            *kept_shape,
-            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
-        )
-        reshaped_x = transposed_x.reshape(tuple(new_shape))
-
-        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-        return max_idx_res
-
-    return argmax

tests/link/jax/test_blas.py (+36)

@@ -0,0 +1,36 @@
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.compile.mode import Mode
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import get_test_value
+from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.link.jax import JAXLinker
+from pytensor.tensor import blas as pt_blas
+from pytensor.tensor.type import tensor3
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+def test_jax_BatchedDot():
+    # tensor3 . tensor3
+    a = tensor3("a")
+    a.tag.test_value = (
+        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
+    )
+    b = tensor3("b")
+    b.tag.test_value = (
+        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
+    )
+    out = pt_blas.BatchedDot()(a, b)
+    fgraph = FunctionGraph([a, b], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    # A dimension mismatch should raise a TypeError for compatibility
+    inputs = [get_test_value(a)[:-1], get_test_value(b)]
+    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
+    jax_mode = Mode(JAXLinker(), opts)
+    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
+    with pytest.raises(TypeError):
+        pytensor_jax_fn(*inputs)

tests/link/jax/test_math.py (+52)

@@ -0,0 +1,52 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import get_test_value
+from pytensor.tensor.math import Argmax, Max, maximum
+from pytensor.tensor.math import max as pt_max
+from pytensor.tensor.type import dvector, matrix, scalar, vector
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_jax_max_and_argmax():
+    # Test that a single output of a multi-output `Op` can be used as input to
+    # another `Op`
+    x = dvector()
+    mx = Max([0])(x)
+    amx = Argmax([0])(x)
+    out = mx * amx
+    out_fg = FunctionGraph([x], [out])
+    compare_jax_and_py(out_fg, [np.r_[1, 2]])
+
+
+def test_dot():
+    y = vector("y")
+    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
+    x = vector("x")
+    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
+    A = matrix("A")
+    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
+    alpha = scalar("alpha")
+    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
+    beta = scalar("beta")
+    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
+
+    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
+    # optimizations are turned on; however, when using JAX mode, it should
+    # leave the expression alone.
+    out = y.dot(alpha * A).dot(x) + beta * y
+    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = maximum(y, x)
+    fgraph = FunctionGraph([y, x], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = pt_max(y)
+    fgraph = FunctionGraph([y], [out])
+    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

tests/link/jax/test_nlinalg.py (+1 -70)

@@ -2,46 +2,16 @@
 import pytest
 
 from pytensor.compile.function import function
-from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-from pytensor.link.jax import JAXLinker
-from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nlinalg
-from pytensor.tensor.math import Argmax, Max, maximum
-from pytensor.tensor.math import max as pt_max
-from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-def test_jax_BatchedDot():
-    # tensor3 . tensor3
-    a = tensor3("a")
-    a.tag.test_value = (
-        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
-    )
-    b = tensor3("b")
-    b.tag.test_value = (
-        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
-    )
-    out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
-    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-    jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
-    with pytest.raises(TypeError):
-        pytensor_jax_fn(*inputs)
-
-
 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
 
@@ -79,45 +49,6 @@ def assert_fn(x, y):
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
 
 
-def test_jax_max_and_argmax():
-    # Test that a single output of a multi-output `Op` can be used as input to
-    # another `Op`
-    x = dvector()
-    mx = Max([0])(x)
-    amx = Argmax([0])(x)
-    out = mx * amx
-    out_fg = FunctionGraph([x], [out])
-    compare_jax_and_py(out_fg, [np.r_[1, 2]])
-
-
-def test_tensor_basics():
-    y = vector("y")
-    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
-    x = vector("x")
-    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
-    A = matrix("A")
-    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
-    alpha = scalar("alpha")
-    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
-    beta = scalar("beta")
-    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
-
-    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
-    # optimizations are turned on; however, when using JAX mode, it should
-    # leave the expression alone.
-    out = y.dot(alpha * A).dot(x) + beta * y
-    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = maximum(y, x)
-    fgraph = FunctionGraph([y, x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = pt_max(y)
-    fgraph = FunctionGraph([y], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 def test_pinv():
     x = matrix("x")
     x_inv = pt_nlinalg.pinv(x)

0 commit comments
