Skip to content

Commit 382c50b

Browse files
Add numba overload for pytensor.tensor.slinalg.block_diag
1 parent 1de23db commit 382c50b

File tree

4 files changed: 56 additions and 6 deletions.

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
from pytensor.link.numba.dispatch import basic as numba_basic
1111
from pytensor.link.numba.dispatch.basic import numba_funcify
12-
from pytensor.tensor.slinalg import SolveTriangular
12+
from pytensor.tensor.slinalg import BlockDiagonalMatrix, SolveTriangular
1313

1414

1515
_PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
273273
return res
274274

275275
return solve_triangular
276+
277+
278+
@numba_funcify.register(BlockDiagonalMatrix)
def numba_funcify_BlockDiagonalMatrix(op, node, **kwargs):
    """Build the Numba implementation of ``BlockDiagonalMatrix``.

    Parameters
    ----------
    op
        The ``BlockDiagonalMatrix`` op being converted (unused here).
    node
        The Apply node; only the dtype of its first output is read.
    **kwargs
        Ignored; accepted for compatibility with the dispatch signature.

    Returns
    -------
    callable
        A jitted ``block_diag(*arrs)`` that places each input matrix on the
        diagonal of a zero-filled output, in argument order.
    """
    dtype = node.outputs[0].dtype

    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
    @numba_basic.numba_njit(inline="never")
    def block_diag(*arrs):
        # Shapes are sizes/offsets, so they must be integers regardless of
        # the (possibly float or complex) dtype of the output.
        shapes = np.array([a.shape for a in arrs], dtype="int")
        out_shape = [int(s) for s in np.sum(shapes, axis=0)]
        # Allocate with the node's output dtype; np.zeros would otherwise
        # default to float64 and disagree with the graph's declared type.
        out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)

        # (r, c) is the top-left corner of the next block on the diagonal.
        r, c = 0, 0
        for arr, shape in zip(arrs, shapes):
            rr, cc = shape
            out[r : r + rr, c : c + cc] = arr
            r += rr
            c += cc
        return out

    return block_diag

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -949,7 +949,7 @@ def make_node(self, *matrices, name=None):
949949
return Apply(self, matrices, [out_type])
950950

951951
def perform(self, node, inputs, output_storage, params=None):
952-
dtype = largest_common_dtype(inputs)
952+
dtype = node.outputs[0].type.dtype
953953
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
954954

955955

tests/link/numba/test_slinalg.py

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytest
5+
from scipy import linalg
56

67
import pytensor
78
import pytensor.tensor as pt
@@ -10,7 +11,6 @@
1011

1112
numba = pytest.importorskip("numba")
1213

13-
1414
ATOL = 0 if config.floatX.endswith("64") else 1e-6
1515
RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
1616
rng = np.random.default_rng(42849)
@@ -102,3 +102,30 @@ def test_solve_triangular_raises_on_nan_inf(value):
102102
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
103103
):
104104
f(A_tri, b)
105+
106+
107+
def test_block_diag():
    """Check the NUMBA-mode ``block_diag`` against hand-built blocks and SciPy."""
    A = pt.matrix("A")
    B = pt.matrix("B")
    C = pt.matrix("C")
    D = pt.matrix("D")
    X = pt.linalg.block_diag(A, B, C, D)
    f = pytensor.function([A, B, C, D], X, mode="NUMBA")

    # Draw from the module-level seeded generator so failures are
    # reproducible, and cast to floatX because ``pt.matrix`` inputs are
    # floatX-typed (float64 values are rejected when floatX is float32).
    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
    B_val = rng.normal(size=(3, 3)).astype(config.floatX)
    C_val = rng.normal(size=(2, 2)).astype(config.floatX)
    D_val = rng.normal(size=(4, 4)).astype(config.floatX)

    X_val = f(A_val, B_val, C_val, D_val)

    # Top-left 8x8 corner holds A and B on its diagonal.
    np.testing.assert_allclose(
        np.block([[A_val, np.zeros((5, 3))], [np.zeros((3, 5)), B_val]]), X_val[:8, :8]
    )
    # Bottom-right 6x6 corner holds C and D on its diagonal.
    np.testing.assert_allclose(
        np.block([[C_val, np.zeros((2, 4))], [np.zeros((4, 2)), D_val]]), X_val[8:, 8:]
    )
    # Off-diagonal corners must be entirely zero.
    np.testing.assert_allclose(np.zeros((8, 6)), X_val[:8, 8:])
    np.testing.assert_allclose(np.zeros((6, 8)), X_val[8:, :8])

    # Full comparison against the SciPy reference implementation.
    X_sp = linalg.block_diag(A_val, B_val, C_val, D_val)
    np.testing.assert_allclose(X_val, X_sp)

tests/tensor/test_slinalg.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -665,6 +665,7 @@ def test_solve_discrete_are_grad():
665665

666666

667667
def test_block_diagonal():
668-
matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
669-
result = block_diag(*matrices)
670-
np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices))
668+
A = np.array([[1.0, 2.0], [3.0, 4.0]])
669+
B = np.array([[5.0, 6.0], [7.0, 8.0]])
670+
result = block_diag(A, B, name="X")
671+
np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))

0 commit comments

Comments
 (0)