Skip to content

Commit e98cbbc

Browse files
Numba dispatch for LU ops
1 parent 679b2f7 commit e98cbbc

File tree

7 files changed

+754
-114
lines changed

7 files changed

+754
-114
lines changed

pytensor/link/numba/dispatch/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
7676
message=(
7777
"(\x1b\\[1m)*" # ansi escape code for bold text
7878
"Cannot cache compiled function "
79-
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
79+
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
8080
"as it uses dynamic globals"
8181
),
8282
category=NumbaWarning,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from collections.abc import Callable
2+
from typing import cast as typing_cast
3+
4+
import numpy as np
5+
from numba import njit as numba_njit
6+
from numba.core.extending import overload
7+
from numba.np.linalg import ensure_lapack
8+
from scipy import linalg
9+
10+
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
11+
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
12+
13+
14+
@numba_njit
def _pivot_to_permutation(p, dtype):
    """
    Turn a sequence of row interchanges into a permutation vector.

    Entry ``p[i]`` is read as "swap position ``i`` with position ``p[i]``". The swaps are applied in order to
    ``arange(len(p))`` (cast to ``dtype``) and the resulting array is returned.
    """
    n = len(p)
    permutation = np.arange(n).astype(dtype)
    for i in range(n):
        j = p[i]
        # Explicit three-step exchange of entries i and j.
        held = permutation[j]
        permutation[j] = permutation[i]
        permutation[i] = held
    return permutation
20+
21+
22+
@numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
    """
    Factor ``a`` with ``_getrf`` and unpack the packed result into separate factors.

    Returns ``(perm, L, U)``: ``U`` is the upper triangle of the packed factorization, ``L`` is its strictly-lower
    triangle plus an identity diagonal, and ``perm`` is the argsort of the permutation built from the pivots.
    The info code returned by ``_getrf`` is not inspected here.
    """
    packed, pivots, _info = _getrf(a, overwrite_a=overwrite_a)

    U = np.triu(packed)
    L = np.eye(packed.shape[-1], dtype=dtype)
    L += np.tril(packed, k=-1)

    # Fortran is 1 indexed, so shift the pivots down by one before using them as positions
    zero_based_pivots = pivots - 1
    perm = np.argsort(_pivot_to_permutation(zero_based_pivots, dtype=dtype))

    return perm, L, U
36+
37+
38+
def _lu_1(
39+
a: np.ndarray,
40+
permute_l: bool,
41+
check_finite: bool,
42+
p_indices: bool,
43+
overwrite_a: bool,
44+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
45+
"""
46+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
47+
48+
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
49+
array of row swaps, such that L[perm] @ U = A.
50+
"""
51+
return typing_cast(
52+
tuple[np.ndarray, np.ndarray, np.ndarray],
53+
linalg.lu(
54+
a,
55+
permute_l=permute_l,
56+
check_finite=check_finite,
57+
p_indices=p_indices,
58+
overwrite_a=overwrite_a,
59+
),
60+
)
61+
62+
63+
def _lu_2(
64+
a: np.ndarray,
65+
permute_l: bool,
66+
check_finite: bool,
67+
p_indices: bool,
68+
overwrite_a: bool,
69+
) -> tuple[np.ndarray, np.ndarray]:
70+
"""
71+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
72+
73+
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
74+
permuted L matrix, PL = P @ L.
75+
"""
76+
return typing_cast(
77+
tuple[np.ndarray, np.ndarray],
78+
linalg.lu(
79+
a,
80+
permute_l=permute_l,
81+
check_finite=check_finite,
82+
p_indices=p_indices,
83+
overwrite_a=overwrite_a,
84+
),
85+
)
86+
87+
88+
def _lu_3(
89+
a: np.ndarray,
90+
permute_l: bool,
91+
check_finite: bool,
92+
p_indices: bool,
93+
overwrite_a: bool,
94+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
95+
"""
96+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
97+
98+
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
99+
matrix, P @ L @ U = A.
100+
"""
101+
return typing_cast(
102+
tuple[np.ndarray, np.ndarray, np.ndarray],
103+
linalg.lu(
104+
a,
105+
permute_l=permute_l,
106+
check_finite=check_finite,
107+
p_indices=p_indices,
108+
overwrite_a=overwrite_a,
109+
),
110+
)
111+
112+
113+
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba overload of ``_lu_1``, the scipy.linalg.lu wrapper for the case where permute_l is False and p_indices
    is True. Returns a tuple of (perm, L, U), where perm is an integer array of row indices such that
    L[perm] @ U = A.

    The compiled implementation only reads ``a`` and ``overwrite_a``; ``permute_l``, ``check_finite`` and
    ``p_indices`` are accepted to keep the signature identical to ``_lu_1`` but are not consulted.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Captured at typing time so the implementation can build L with the input's dtype.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        return perm, L, U

    return impl
142+
143+
144+
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload of ``_lu_2``, the scipy.linalg.lu wrapper for the case where permute_l is True. Returns a
    tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L, so that PL @ U = A.

    The compiled implementation only reads ``a`` and ``overwrite_a``; ``permute_l``, ``check_finite`` and
    ``p_indices`` are accepted to keep the signature identical to ``_lu_2`` but are not consulted.
    """

    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Captured at typing time so the implementation can build L with the input's dtype.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Reordering the rows of L by perm is the same as left-multiplying by the permutation matrix.
        PL = L[perm]

        return PL, U

    return impl
174+
175+
176+
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba overload of ``_lu_3``, the scipy.linalg.lu wrapper for the case where both permute_l and p_indices are
    False. Returns a tuple of (P, L, U), where P is a permutation matrix such that P @ L @ U = A.

    The compiled implementation only reads ``a`` and ``overwrite_a``; ``permute_l``, ``check_finite`` and
    ``p_indices`` are accepted to keep the signature identical to ``_lu_3`` but are not consulted.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Captured at typing time so the implementation can build L and P with the input's dtype.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Row-indexing an identity matrix by perm yields the dense permutation matrix.
        P = np.eye(a.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
from numba.core.extending import overload
5+
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
6+
from scipy import linalg
7+
8+
from pytensor.link.numba.dispatch.linalg._LAPACK import (
9+
_LAPACK,
10+
_get_underlying_float,
11+
int_ptr_to_val,
12+
val_to_int_ptr,
13+
)
14+
from pytensor.link.numba.dispatch.linalg.utils import (
15+
_check_scipy_linalg_matrix,
16+
)
17+
18+
19+
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
20+
"""
21+
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
22+
returns an info code with diagnostic information.
23+
"""
24+
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
25+
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
26+
27+
return A_copy, ipiv, info
28+
29+
30+
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """
    Numba overload of ``_getrf`` that calls the LAPACK ``xgetrf`` routine selected for the dtype of ``A``.

    The compiled implementation returns ``(A_copy, IPIV, INFO)``: the array the routine wrote into, the pivot
    array exactly as LAPACK filled it (callers in this file subtract 1 to obtain 0-based indices), and the
    integer info code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    # The data buffer is viewed as this type when handed to LAPACK
    # (presumably the real float type underlying a complex dtype - TODO confirm).
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        # Only reuse the caller's buffer when that was requested AND it is already Fortran-contiguous;
        # otherwise the routine works on a Fortran-ordered copy.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        # xGETRF produces min(M, N) pivots. Sizing the buffer to exactly that keeps uninitialised
        # trailing entries out of the returned array when M < N.
        IPIV = np.empty(min(_M, _N), dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        return A_copy, IPIV, int_ptr_to_val(INFO)

    return impl
61+
62+
63+
def _lu_factor(A: np.ndarray, overwrite_a: bool = False):
64+
"""
65+
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
66+
Pytensor.
67+
"""
68+
return linalg.lu_factor(A, overwrite_a=overwrite_a)
69+
70+
71+
@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload of ``_lu_factor`` built on ``_getrf``.

    The compiled implementation returns the packed factorization together with the pivots shifted to 0-based
    indices, and raises ``np.linalg.LinAlgError`` whenever the info code from ``_getrf`` is non-zero.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")

    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
        packed, pivots, info = _getrf(A, overwrite_a=overwrite_a)
        pivots -= 1  # LAPACK uses 1-based indexing, convert to 0-based

        if info == 0:
            return packed, pivots
        raise np.linalg.LinAlgError("LU decomposition failed")

    return impl

0 commit comments

Comments
 (0)