|
24 | 24 | Solve,
|
25 | 25 | SolveTriangular,
|
26 | 26 | )
|
| 27 | +from pytensor.tensor.type import complex_dtypes |
| 28 | + |
| 29 | + |
| 30 | +_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( |
| 31 | + "Complex dtype for {op} not supported in numba mode. " |
| 32 | + "If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor" |
| 33 | +) |
27 | 34 |
|
28 | 35 |
|
29 | 36 | @numba_basic.numba_njit(inline="always")
|
@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
|
199 | 206 | b_ndim = op.b_ndim
|
200 | 207 |
|
201 | 208 | dtype = node.inputs[0].dtype
|
202 |
| - if str(dtype).startswith("complex"): |
| 209 | + if dtype in complex_dtypes: |
203 | 210 | raise NotImplementedError(
|
204 |
| - "Complex inputs not currently supported by solve_triangular in Numba mode" |
| 211 | + _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular") |
205 | 212 | )
|
206 | 213 |
|
207 | 214 | @numba_basic.numba_njit(inline="always")
|
@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
|
299 | 306 | on_error = op.on_error
|
300 | 307 |
|
301 | 308 | dtype = node.inputs[0].dtype
|
302 |
| - if str(dtype).startswith("complex"): |
303 |
| - raise NotImplementedError( |
304 |
| - "Complex inputs not currently supported by cholesky in Numba mode" |
305 |
| - ) |
| 309 | + if dtype in complex_dtypes: |
| 310 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
306 | 311 |
|
307 | 312 | @numba_basic.numba_njit(inline="always")
|
308 | 313 | def nb_cholesky(a):
|
@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs):
|
1089 | 1094 | transposed = False # TODO: Solve doesnt currently allow the transposed argument
|
1090 | 1095 |
|
1091 | 1096 | dtype = node.inputs[0].dtype
|
1092 |
| - if str(dtype).startswith("complex"): |
1093 |
| - raise NotImplementedError( |
1094 |
| - "Complex inputs not currently supported by solve in Numba mode" |
1095 |
| - ) |
| 1097 | + if dtype in complex_dtypes: |
| 1098 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
1096 | 1099 |
|
1097 | 1100 | if assume_a == "gen":
|
1098 | 1101 | solve_fn = _solve_gen
|
@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
|
1206 | 1209 | check_finite = op.check_finite
|
1207 | 1210 |
|
1208 | 1211 | dtype = node.inputs[0].dtype
|
1209 |
| - if str(dtype).startswith("complex"): |
1210 |
| - raise NotImplementedError( |
1211 |
| - "Complex inputs not currently supported by cho_solve in Numba mode" |
1212 |
| - ) |
| 1212 | + if dtype in complex_dtypes: |
| 1213 | + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) |
1213 | 1214 |
|
1214 | 1215 | @numba_basic.numba_njit(inline="always")
|
1215 | 1216 | def cho_solve(c, b):
|
|
0 commit comments