Skip to content

Commit 7a5ea0f

Browse files
ricardoV94jessegrabowskiaseyboldt
committed
Adapt Numba vectorize iterator for RandomVariables
Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 18dcf62 commit 7a5ea0f

File tree

3 files changed

+325
-158
lines changed

3 files changed

+325
-158
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def numba_njit(*args, **kwargs):
6565
# Supress caching warnings
6666
warnings.filterwarnings(
6767
"ignore",
68-
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
68+
message='Cannot cache compiled function "(numba_funcified_fgraph|store_core_outputs)"',
6969
category=NumbaWarning,
7070
)
7171

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_jit_options,
2525
_vectorized,
2626
encode_literals,
27+
store_core_outputs,
2728
)
2829
from pytensor.link.utils import compile_function_src, get_name_for_object
2930
from pytensor.scalar.basic import (
@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
480481
**kwargs,
481482
)
482483

484+
nin = len(node.inputs)
485+
nout = len(node.outputs)
486+
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
487+
483488
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
484489
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
485490
output_dtypes = tuple(out.type.dtype for out in node.outputs)
486491
inplace_pattern = tuple(op.inplace_pattern.items())
492+
core_output_shapes = tuple(() for _ in range(nout))
487493

488494
# numba doesn't support nested literals right now...
489495
input_bc_patterns_enc = encode_literals(input_bc_patterns)
@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
493499

494500
def elemwise_wrapper(*inputs):
495501
return _vectorized(
496-
scalar_op_fn,
502+
core_op_fn,
497503
input_bc_patterns_enc,
498504
output_bc_patterns_enc,
499505
output_dtypes_enc,
500506
inplace_pattern_enc,
507+
(), # constant_inputs
501508
inputs,
509+
core_output_shapes, # core_shapes
510+
None, # size
502511
)
503512

504513
# Pure python implementation, that will be used in tests

0 commit comments

Comments
 (0)