Skip to content

Commit 462191b

Browse files
overloads for scipy.linalg.schur, .qz, .ordqz, .solve_discrete_lyapunov, and .solve_continuous_lyapunov
1 parent b6939a0 commit 462191b

File tree

6 files changed

+1442
-0
lines changed

6 files changed

+1442
-0
lines changed

numba_scipy/linalg/LAPACK.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from numba.extending import get_cython_function_address
2+
from numba.np.linalg import ensure_lapack, _blas_kinds
3+
import ctypes
4+
5+
_PTR = ctypes.POINTER
6+
7+
_dbl = ctypes.c_double
8+
_float = ctypes.c_float
9+
_char = ctypes.c_char
10+
_int = ctypes.c_int
11+
12+
_ptr_float = _PTR(_float)
13+
_ptr_dbl = _PTR(_dbl)
14+
_ptr_char = _PTR(_char)
15+
_ptr_int = _PTR(_int)
16+
17+
18+
def _get_float_pointer_for_dtype(blas_dtype):
19+
if blas_dtype in ['s', 'c']:
20+
return _ptr_float
21+
elif blas_dtype in ['d', 'z']:
22+
return _ptr_dbl
23+
24+
25+
class _LAPACK:
26+
"""
27+
Functions to return type signatures for wrapped
28+
LAPACK functions.
29+
"""
30+
31+
def __init__(self):
32+
ensure_lapack()
33+
34+
@classmethod
35+
def test_blas_kinds(cls, dtype):
36+
return _blas_kinds[dtype]
37+
38+
@classmethod
39+
def numba_rgees(cls, dtype):
40+
d = _blas_kinds[dtype]
41+
func_name = f'{d}gees'
42+
float_pointer = _get_float_pointer_for_dtype(d)
43+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
44+
functype = ctypes.CFUNCTYPE(None,
45+
_ptr_int, # JOBVS
46+
_ptr_int, # SORT
47+
_ptr_int, # SELECT
48+
_ptr_int, # N
49+
float_pointer, # A
50+
_ptr_int, # LDA
51+
_ptr_int, # SDIM
52+
float_pointer, # WR
53+
float_pointer, # WI
54+
float_pointer, # VS
55+
_ptr_int, # LDVS
56+
float_pointer, # WORK
57+
_ptr_int, # LWORK
58+
_ptr_int, # BWORK
59+
_ptr_int) # INFO
60+
return functype(addr)
61+
62+
@classmethod
63+
def numba_cgees(cls, dtype):
64+
d = _blas_kinds[dtype]
65+
func_name = f'{d}gees'
66+
float_pointer = _get_float_pointer_for_dtype(d)
67+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
68+
functype = ctypes.CFUNCTYPE(None,
69+
_ptr_int, # JOBVS
70+
_ptr_int, # SORT
71+
_ptr_int, # SELECT
72+
_ptr_int, # N
73+
float_pointer, # A
74+
_ptr_int, # LDA
75+
_ptr_int, # SDIM
76+
float_pointer, # W
77+
float_pointer, # VS
78+
_ptr_int, # LDVS
79+
float_pointer, # WORK
80+
_ptr_int, # LWORK
81+
float_pointer, # RWORK
82+
_ptr_int, # BWORK
83+
_ptr_int) # INFO
84+
return functype(addr)
85+
86+
@classmethod
87+
def numba_rgges(cls, dtype):
88+
d = _blas_kinds[dtype]
89+
func_name = f'{d}gges'
90+
float_pointer = _get_float_pointer_for_dtype(d)
91+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
92+
93+
functype = ctypes.CFUNCTYPE(None,
94+
_ptr_int, # JOBVSL
95+
_ptr_int, # JOBVSR
96+
_ptr_int, # SORT
97+
_ptr_int, # SELCTG
98+
_ptr_int, # N
99+
float_pointer, # A
100+
_ptr_int, # LDA
101+
float_pointer, # B
102+
_ptr_int, # LDB
103+
_ptr_int, # SDIM
104+
float_pointer, # ALPHAR
105+
float_pointer, # ALPHAI
106+
float_pointer, # BETA
107+
float_pointer, # VSL
108+
_ptr_int, # LDVSL
109+
float_pointer, # VSR
110+
_ptr_int, # LDVSR
111+
float_pointer, # WORK
112+
_ptr_int, # LWORK
113+
_ptr_int, # BWORK
114+
_ptr_int) # INFO
115+
return functype(addr)
116+
117+
@classmethod
118+
def numba_cgges(cls, dtype):
119+
d = _blas_kinds[dtype]
120+
func_name = f'{d}gges'
121+
float_pointer = _get_float_pointer_for_dtype(d)
122+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
123+
124+
functype = ctypes.CFUNCTYPE(None,
125+
_ptr_int, # JOBVSL
126+
_ptr_int, # JOBVSR
127+
_ptr_int, # SORT
128+
_ptr_int, # SELCTG
129+
_ptr_int, # N
130+
float_pointer, # A, complex
131+
_ptr_int, # LDA
132+
float_pointer, # B, complex
133+
_ptr_int, # LDB
134+
_ptr_int, # SDIM
135+
float_pointer, # ALPHA, complex
136+
float_pointer, # BETA, complex
137+
float_pointer, # VSL, complex
138+
_ptr_int, # LDVSL
139+
float_pointer, # VSR, complex
140+
_ptr_int, # LDVSR
141+
float_pointer, # WORK, complex
142+
_ptr_int, # LWORK
143+
float_pointer, # RWORK
144+
_ptr_int, # BWORK
145+
_ptr_int) # INFO
146+
return functype(addr)
147+
148+
@classmethod
149+
def numba_rtgsen(cls, dtype):
150+
d = _blas_kinds[dtype]
151+
func_name = f'{d}tgsen'
152+
float_pointer = _get_float_pointer_for_dtype(d)
153+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
154+
155+
functype = ctypes.CFUNCTYPE(None,
156+
_ptr_int, # IJOB
157+
_ptr_int, # WANTQ
158+
_ptr_int, # WANTZ
159+
_ptr_int, # SELECT
160+
_ptr_int, # N
161+
float_pointer, # A
162+
_ptr_int, # LDA
163+
float_pointer, # B
164+
_ptr_int, # LDB
165+
float_pointer, # ALPHAR
166+
float_pointer, # ALPHAI
167+
float_pointer, # BETA
168+
float_pointer, # Q
169+
_ptr_int, # LDQ
170+
float_pointer, # Z
171+
_ptr_int, # LDZ
172+
_ptr_int, # M
173+
float_pointer, # PL
174+
float_pointer, # PR
175+
float_pointer, # DIF
176+
float_pointer, # WORK
177+
_ptr_int, # LWORK
178+
_ptr_int, # IWORK
179+
_ptr_int, # LIWORK
180+
_ptr_int) # INFO
181+
return functype(addr)
182+
183+
@classmethod
184+
def numba_ctgsen(cls, dtype):
185+
d = _blas_kinds[dtype]
186+
func_name = f'{d}tgsen'
187+
float_pointer = _get_float_pointer_for_dtype(d)
188+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
189+
190+
functype = ctypes.CFUNCTYPE(None,
191+
_ptr_int, # IJOB
192+
_ptr_int, # WANTQ
193+
_ptr_int, # WANTZ
194+
_ptr_int, # SELECT
195+
_ptr_int, # N
196+
float_pointer, # A
197+
_ptr_int, # LDA
198+
float_pointer, # B
199+
_ptr_int, # LDB
200+
float_pointer, # ALPHA
201+
float_pointer, # BETA
202+
float_pointer, # Q
203+
_ptr_int, # LDQ
204+
float_pointer, # Z
205+
_ptr_int, # LDZ
206+
_ptr_int, # M
207+
float_pointer, # PL
208+
float_pointer, # PR
209+
float_pointer, # DIF
210+
float_pointer, # WORK
211+
_ptr_int, # LWORK
212+
_ptr_int, # IWORK
213+
_ptr_int, # LIWORK
214+
_ptr_int) # INFO
215+
return functype(addr)
216+
217+
@classmethod
218+
def numba_xtrsyl(cls, dtype):
219+
d = _blas_kinds[dtype]
220+
func_name = f'{d}trsyl'
221+
float_pointer = _get_float_pointer_for_dtype(d)
222+
addr = get_cython_function_address('scipy.linalg.cython_lapack', func_name)
223+
224+
functype = ctypes.CFUNCTYPE(None,
225+
_ptr_int, # TRANA
226+
_ptr_int, # TRANB
227+
_ptr_int, # ISGN
228+
_ptr_int, # M
229+
_ptr_int, # N
230+
float_pointer, # A
231+
_ptr_int, # LDA
232+
float_pointer, # B
233+
_ptr_int, # LDB
234+
float_pointer, # C
235+
_ptr_int, # LDC
236+
float_pointer, # SCALE
237+
_ptr_int) # INFO
238+
return functype(addr)

numba_scipy/linalg/__init__.py

Whitespace-only changes.

numba_scipy/linalg/intrinsics.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from numba.core import types, cgutils
2+
from numba.extending import intrinsic
3+
4+
5+
@intrinsic
6+
def val_to_dptr(typingctx, data):
7+
def impl(context, builder, signature, args):
8+
ptr = cgutils.alloca_once_value(builder, args[0])
9+
return ptr
10+
11+
sig = types.CPointer(types.float64)(types.float64)
12+
return sig, impl
13+
14+
15+
@intrinsic
16+
def val_to_zptr(typingctx, data):
17+
def impl(context, builder, signature, args):
18+
ptr = cgutils.alloca_once_value(builder, args[0])
19+
return ptr
20+
21+
sig = types.CPointer(types.complex128)(types.complex128)
22+
return sig, impl
23+
24+
25+
@intrinsic
26+
def val_to_sptr(typingctx, data):
27+
def impl(context, builder, signature, args):
28+
ptr = cgutils.alloca_once_value(builder, args[0])
29+
return ptr
30+
31+
sig = types.CPointer(types.float32)(types.float32)
32+
return sig, impl
33+
34+
35+
@intrinsic
36+
def val_to_int_ptr(typingctx, data):
37+
def impl(context, builder, signature, args):
38+
ptr = cgutils.alloca_once_value(builder, args[0])
39+
return ptr
40+
41+
sig = types.CPointer(types.int32)(types.int32)
42+
return sig, impl
43+
44+
45+
@intrinsic
46+
def int_ptr_to_val(typingctx, data):
47+
def impl(context, builder, signature, args):
48+
val = builder.load(args[0])
49+
return val
50+
51+
sig = types.int32(types.CPointer(types.int32))
52+
return sig, impl
53+
54+
55+
@intrinsic
56+
def dptr_to_val(typingctx, data):
57+
def impl(context, builder, signature, args):
58+
val = builder.load(args[0])
59+
return val
60+
61+
sig = types.float64(types.CPointer(types.float64))
62+
return sig, impl
63+
64+
65+
@intrinsic
66+
def sptr_to_val(typingctx, data):
67+
def impl(context, builder, signature, args):
68+
val = builder.load(args[0])
69+
return val
70+
71+
sig = types.float32(types.CPointer(types.float32))
72+
return sig, impl

0 commit comments

Comments
 (0)