Skip to content

Commit 491111b

Browse files
add jax overload for pytensor.tensor.slinalg.block_diag
1 parent 382c50b commit 491111b

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import jax
22

33
from pytensor.link.jax.dispatch.basic import jax_funcify
4-
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
4+
from pytensor.tensor.slinalg import (
5+
BlockDiagonalMatrix,
6+
Cholesky,
7+
Solve,
8+
SolveTriangular,
9+
)
510

611

712
@jax_funcify.register(Cholesky)
@@ -45,3 +50,11 @@ def solve_triangular(A, b):
4550
)
4651

4752
return solve_triangular
53+
54+
55+
@jax_funcify.register(BlockDiagonalMatrix)
56+
def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
57+
def block_diag(*inputs):
58+
return jax.scipy.linalg.block_diag(*inputs)
59+
60+
return block_diag

tests/link/jax/test_slinalg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,23 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
129129
np.arange(10).astype(config.floatX),
130130
],
131131
)
132+
133+
134+
def test_jax_block_diag():
135+
A = matrix("A")
136+
B = matrix("B")
137+
C = matrix("C")
138+
D = matrix("D")
139+
140+
out = pt_slinalg.block_diag(A, B, C, D)
141+
142+
out_fg = FunctionGraph([A, B, C, D], [out])
143+
compare_jax_and_py(
144+
out_fg,
145+
[
146+
np.random.normal(size=(5, 5)).astype(config.floatX),
147+
np.random.normal(size=(3, 3)).astype(config.floatX),
148+
np.random.normal(size=(2, 2)).astype(config.floatX),
149+
np.random.normal(size=(4, 4)).astype(config.floatX),
150+
],
151+
)

0 commit comments

Comments
 (0)