Skip to content

Commit bbe663d

Browse files
jessegrabowski authored and ricardoV94 committed
Implement numba dispatch for all linalg.solve modes
1 parent 8e5e8a4 commit bbe663d

File tree

7 files changed

+1756
-357
lines changed

7 files changed

+1756
-357
lines changed
+392
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
import ctypes
2+
3+
import numpy as np
4+
from numba.core import cgutils, types
5+
from numba.core.extending import get_cython_function_address, intrinsic
6+
from numba.np.linalg import ensure_lapack, get_blas_kind
7+
8+
9+
# Factory used to build the pointer types below.
_PTR = ctypes.POINTER

# Short aliases for the ctypes scalar types used in the LAPACK prototypes.
_dbl, _float, _char, _int = (
    ctypes.c_double,
    ctypes.c_float,
    ctypes.c_char,
    ctypes.c_int,
)

# Pointer-to-scalar types; LAPACK prototypes in this module declare every argument as a pointer.
_ptr_float, _ptr_dbl, _ptr_char, _ptr_int = (
    _PTR(_float),
    _PTR(_dbl),
    _PTR(_char),
    _PTR(_int),
)
20+
21+
22+
def _get_lapack_ptr_and_ptr_type(dtype, name):
    """
    Locate the LAPACK routine ``name`` specialised for ``dtype``.

    The routine name is formed by prefixing ``name`` with the BLAS kind returned by
    ``get_blas_kind(dtype)``; its address is then looked up in ``scipy.linalg.cython_lapack``.

    Parameters
    ----------
    dtype
        Data type passed on to ``get_blas_kind``.
    name : str
        Routine name without its type prefix, e.g. ``"getrf"``.

    Returns
    -------
    tuple
        ``(address, pointer_type)`` where ``address`` is the value returned by
        ``get_cython_function_address`` and ``pointer_type`` is the ctypes pointer type
        selected by ``_get_float_pointer_for_dtype`` for the BLAS kind.
    """
    kind = get_blas_kind(dtype)
    pointer_type = _get_float_pointer_for_dtype(kind)
    address = get_cython_function_address(
        "scipy.linalg.cython_lapack", f"{kind}{name}"
    )
    return address, pointer_type
29+
30+
31+
def _get_underlying_float(dtype):
    """
    Return the numpy dtype of the real component of ``dtype``.

    ``dtype`` is compared by its string form: ``"complex64"`` maps to ``float32`` and
    ``"complex128"`` maps to ``float64``. Any other string form is handed to ``np.dtype``
    unchanged.
    """
    complex_to_real = {"complex64": "float32", "complex128": "float64"}
    type_name = str(dtype)
    return np.dtype(complex_to_real.get(type_name, type_name))
40+
41+
42+
def _get_float_pointer_for_dtype(blas_dtype):
    """
    Map a one-letter BLAS kind to the ctypes pointer type used for its array arguments.

    ``"s"`` and ``"c"`` give ``_ptr_float``; ``"d"`` and ``"z"`` give ``_ptr_dbl``.
    Any other value yields ``None``.
    """
    if blas_dtype in ("s", "c"):
        return _ptr_float
    if blas_dtype in ("d", "z"):
        return _ptr_dbl
    return None
47+
48+
49+
def _get_output_ctype(dtype):
    """
    Return the ctypes scalar type used as the return type of real-valued LAPACK functions.

    ``dtype`` is compared by its string form: ``float32``/``complex64`` give ``_float`` and
    ``float64``/``complex128`` give ``_dbl``. Any other value yields ``None``.
    """
    type_name = str(dtype)
    if type_name in ("float32", "complex64"):
        return _float
    if type_name in ("float64", "complex128"):
        return _dbl
    return None
55+
56+
57+
@intrinsic
def sptr_to_val(typingctx, data):
    """
    Numba intrinsic that loads the ``float32`` value a pointer refers to.

    The signature is fixed to ``float32(CPointer(float32))``; ``typingctx`` and ``data`` are
    not inspected.
    """

    def codegen(context, builder, signature, args):
        # Emit a single load from the pointer argument.
        return builder.load(args[0])

    return types.float32(types.CPointer(types.float32)), codegen
65+
66+
67+
@intrinsic
def dptr_to_val(typingctx, data):
    """
    Numba intrinsic that loads the ``float64`` value a pointer refers to.

    The signature is fixed to ``float64(CPointer(float64))``; ``typingctx`` and ``data`` are
    not inspected.
    """

    def codegen(context, builder, signature, args):
        # Emit a single load from the pointer argument.
        return builder.load(args[0])

    return types.float64(types.CPointer(types.float64)), codegen
75+
76+
77+
@intrinsic
def int_ptr_to_val(typingctx, data):
    """
    Numba intrinsic that loads the ``int32`` value a pointer refers to.

    The signature is fixed to ``int32(CPointer(int32))``; ``typingctx`` and ``data`` are
    not inspected.
    """

    def codegen(context, builder, signature, args):
        # Emit a single load from the pointer argument.
        return builder.load(args[0])

    return types.int32(types.CPointer(types.int32)), codegen
85+
86+
87+
@intrinsic
def val_to_int_ptr(typingctx, data):
    """
    Numba intrinsic that turns an ``int32`` value into a pointer to that value.

    The value is placed in a slot obtained from ``cgutils.alloca_once_value`` and the slot's
    pointer is returned. The signature is fixed to ``CPointer(int32)(int32)``; ``typingctx``
    and ``data`` are not inspected.
    """

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return types.CPointer(types.int32)(types.int32), codegen
95+
96+
97+
@intrinsic
def val_to_sptr(typingctx, data):
    """
    Numba intrinsic that turns a ``float32`` value into a pointer to that value.

    The value is placed in a slot obtained from ``cgutils.alloca_once_value`` and the slot's
    pointer is returned. The signature is fixed to ``CPointer(float32)(float32)``;
    ``typingctx`` and ``data`` are not inspected.
    """

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return types.CPointer(types.float32)(types.float32), codegen
105+
106+
107+
@intrinsic
def val_to_zptr(typingctx, data):
    """
    Numba intrinsic that turns a ``complex128`` value into a pointer to that value.

    The value is placed in a slot obtained from ``cgutils.alloca_once_value`` and the slot's
    pointer is returned. The signature is fixed to ``CPointer(complex128)(complex128)``;
    ``typingctx`` and ``data`` are not inspected.
    """

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return types.CPointer(types.complex128)(types.complex128), codegen
115+
116+
117+
@intrinsic
def val_to_dptr(typingctx, data):
    """
    Numba intrinsic that turns a ``float64`` value into a pointer to that value.

    The value is placed in a slot obtained from ``cgutils.alloca_once_value`` and the slot's
    pointer is returned. The signature is fixed to ``CPointer(float64)(float64)``;
    ``typingctx`` and ``data`` are not inspected.
    """

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return types.CPointer(types.float64)(types.float64), codegen
125+
126+
127+
class _LAPACK:
    """
    Functions to return type signatures for wrapped LAPACK functions.

    Each ``numba_x<routine>`` classmethod looks up the routine matching ``dtype`` in
    ``scipy.linalg.cython_lapack`` and returns it wrapped in a ``ctypes`` function prototype.
    Every argument is declared as a pointer; the trailing comment on each argument type names
    the LAPACK parameter in that position.

    NOTE(review): the argument lists follow the real-valued (s/d) routines. The complex
    variants of some routines (e.g. xGECON, xSYCON, xPOCON) take different workspace
    arguments -- confirm against the LAPACK reference before using them with complex dtypes.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        # Delegates the LAPACK availability check to numba.
        ensure_lapack()

    @classmethod
    def numba_xtrtrs(cls, dtype):
        """
        Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.

        Called by scipy.linalg.solve_triangular
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # TRANS
            _ptr_int,  # DIAG
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )

        return functype(lapack_ptr)

    @classmethod
    def numba_xpotrf(cls, dtype):
        """
        Compute the Cholesky factorization of a real symmetric positive definite matrix.

        Called by scipy.linalg.cholesky
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xpotrs(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
        factorization computed by numba_xpotrf.

        Called by scipy.linalg.cho_solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrs")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xlange(cls, dtype):
        """
        Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
        a general M-by-N matrix A.

        Unlike most wrappers in this class, the wrapped routine returns a value: its return type is the real scalar
        type given by ``_get_output_ctype(dtype)``.

        Called by scipy.linalg.solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange")
        output_ctype = _get_output_ctype(dtype)
        functype = ctypes.CFUNCTYPE(
            output_ctype,  # Output
            _ptr_int,  # NORM
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # WORK
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xlamch(cls, dtype):
        """
        Determine machine precision for floating point arithmetic.

        LAPACK only provides real-valued variants of this routine (``slamch`` and ``dlamch``), so complex dtypes are
        mapped to the routine of their underlying real precision. This matches the return type chosen by
        ``_get_output_ctype``, which already maps complex dtypes to the corresponding real scalar type.
        """
        kind = get_blas_kind(dtype)
        # There is no clamch/zlamch: fall back to the real routine of the same precision.
        real_kind = {"c": "s", "z": "d"}.get(kind, kind)
        lapack_ptr = get_cython_function_address(
            "scipy.linalg.cython_lapack", f"{real_kind}lamch"
        )
        output_ctype = _get_output_ctype(dtype)
        functype = ctypes.CFUNCTYPE(
            output_ctype,  # Output
            _ptr_int,  # CMACH
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xgecon(cls, dtype):
        """
        Estimates the condition number of a matrix A, using the LU factorization computed by numba_xgetrf.

        Called by scipy.linalg.solve when assume_a == "gen"
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # NORM
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xgetrf(cls, dtype):
        """
        Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.

        Called by scipy.linalg.lu_factor
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # M
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xgetrs(cls, dtype):
        """
        Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
        factorization computed by GETRF.

        Called by scipy.linalg.lu_solve
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # TRANS
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xsysv(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
        factorizing A into LDL^T or UDU^T form, depending on the value of UPLO

        Called by scipy.linalg.solve when assume_a == "sym"
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # B
            _ptr_int,  # LDB
            float_pointer,  # WORK
            _ptr_int,  # LWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xsycon(cls, dtype):
        """
        Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
        computed by xSYTRF.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            _ptr_int,  # IPIV
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xpocon(cls, dtype):
        """
        Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
        computed by potrf.

        Called by scipy.linalg.solve when assume_a == "pos"
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # ANORM
            float_pointer,  # RCOND
            float_pointer,  # WORK
            _ptr_int,  # IWORK
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

    @classmethod
    def numba_xposv(cls, dtype):
        """
        Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
        factorization computed by potrf.
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv")
        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # UPLO
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )
        return functype(lapack_ptr)

0 commit comments

Comments
 (0)