Skip to content

Commit 39704d1

Browse files
committed
Add call for issue in not implemented complex lapack routines
1 parent 8c97bb2 commit 39704d1

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

pytensor/link/numba/dispatch/slinalg.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
Solve,
2525
SolveTriangular,
2626
)
27+
from pytensor.tensor.type import complex_dtypes
28+
29+
30+
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
31+
"Complex dtype for {op} not supported in numba mode. "
32+
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
33+
)
2734

2835

2936
@numba_basic.numba_njit(inline="always")
@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
199206
b_ndim = op.b_ndim
200207

201208
dtype = node.inputs[0].dtype
202-
if str(dtype).startswith("complex"):
209+
if dtype in complex_dtypes:
203210
raise NotImplementedError(
204-
"Complex inputs not currently supported by solve_triangular in Numba mode"
211+
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
205212
)
206213

207214
@numba_basic.numba_njit(inline="always")
@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
299306
on_error = op.on_error
300307

301308
dtype = node.inputs[0].dtype
302-
if str(dtype).startswith("complex"):
303-
raise NotImplementedError(
304-
"Complex inputs not currently supported by cholesky in Numba mode"
305-
)
309+
if dtype in complex_dtypes:
310+
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
306311

307312
@numba_basic.numba_njit(inline="always")
308313
def nb_cholesky(a):
@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs):
10891094
transposed = False # TODO: Solve doesnt currently allow the transposed argument
10901095

10911096
dtype = node.inputs[0].dtype
1092-
if str(dtype).startswith("complex"):
1093-
raise NotImplementedError(
1094-
"Complex inputs not currently supported by solve in Numba mode"
1095-
)
1097+
if dtype in complex_dtypes:
1098+
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
10961099

10971100
if assume_a == "gen":
10981101
solve_fn = _solve_gen
@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
12061209
check_finite = op.check_finite
12071210

12081211
dtype = node.inputs[0].dtype
1209-
if str(dtype).startswith("complex"):
1210-
raise NotImplementedError(
1211-
"Complex inputs not currently supported by cho_solve in Numba mode"
1212-
)
1212+
if dtype in complex_dtypes:
1213+
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
12131214

12141215
@numba_basic.numba_njit(inline="always")
12151216
def cho_solve(c, b):

0 commit comments

Comments
 (0)