
Commit 5907874

.hacks
1 parent 17ef6a4 commit 5907874

6 files changed: +59 −13 lines changed

pytensor/compile/mode.py

Lines changed: 1 addition & 0 deletions

@@ -507,6 +507,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
+    "OLD_FAST_RUN": Mode("cvm", "fast_run"),
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,

pytensor/link/numba/dispatch/basic.py

Lines changed: 11 additions & 11 deletions

@@ -12,7 +12,7 @@
 import scipy.special
 from llvmlite import ir
 from numba import types
-from numba.core.errors import NumbaWarning, TypingError
+from numba.core.errors import TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 from numba.extending import box, overload

@@ -71,16 +71,16 @@ def numba_njit(*args, fastmath=None, **kwargs):

     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
-    warnings.filterwarnings(
-        "ignore",
-        message=(
-            "(\x1b\\[1m)*"  # ansi escape code for bold text
-            "Cannot cache compiled function "
-            '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
-            "as it uses dynamic globals"
-        ),
-        category=NumbaWarning,
-    )
+    # warnings.filterwarnings(
+    #     "ignore",
+    #     message=(
+    #         "(\x1b\\[1m)*"  # ansi escape code for bold text
+    #         "Cannot cache compiled function "
+    #         '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
+    #         "as it uses dynamic globals"
+    #     ),
+    #     category=NumbaWarning,
+    # )

     if len(args) > 0 and callable(args[0]):
         return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
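
Note: with this filter commented out, the "Cannot cache compiled function ... as it uses dynamic globals" NumbaWarning will surface again. If it becomes noisy, a caller-side filter along these lines could be applied instead (a hypothetical sketch, not part of this commit; compiled_fn and test_inputs are placeholders):

import warnings
from numba.core.errors import NumbaWarning

with warnings.catch_warnings():
    # Silence the cache warnings only around the call that triggers compilation.
    warnings.simplefilter("ignore", NumbaWarning)
    result = compiled_fn(*test_inputs)  # placeholder for whatever triggers compilation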

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 6 additions & 2 deletions

@@ -16,7 +16,6 @@
     _jit_options,
     _vectorized,
     encode_literals,
-    store_core_outputs,
 )
 from pytensor.link.utils import compile_function_src
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple

@@ -276,7 +275,12 @@ def numba_funcify_Elemwise(op, node, **kwargs):

     nin = len(node.inputs)
     nout = len(node.outputs)
-    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
+    # core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
+    if isinstance(op.scalar_op, Mul) and len(node.inputs) == 2:
+
+        @numba_njit
+        def core_op_fn(x, y, out):
+            out[...] = x * y

     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
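
Note: with store_core_outputs no longer used here, core_op_fn is only defined when the Elemwise wraps a two-input Mul; the hand-written kernel hard-codes what the generic wrapper would otherwise do. A sketch of the kind of graph that hits this specialized path (illustrative, not from the commit):

import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
out = a * b  # Elemwise(Mul) with exactly two inputs -> hand-written core_op_fn
f = pytensor.function([a, b], out, mode="NUMBA")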

pytensor/link/numba/dispatch/scalar.py

Lines changed: 8 additions & 0 deletions

@@ -196,6 +196,14 @@ def numba_funcify_Add(op, node, **kwargs):

 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
+    if len(node.inputs) == 2:
+
+        @numba_basic.numba_njit
+        def binary_mul(x, y):
+            return x * y
+
+        return binary_mul
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 23 additions & 0 deletions

@@ -13,6 +13,7 @@
     AdvancedSubtensor1,
     IncSubtensor,
     Subtensor,
+    get_idx_list,
 )
 from pytensor.tensor.type_other import NoneTypeT, SliceType

@@ -95,6 +96,9 @@ def {function_name}({", ".join(input_names)}):
         return np.asarray(z)
     """

+    print()
+    node.dprint(depth=2, print_type=True)
+    print("subtensor_def_src:", subtensor_def_src)
     func = compile_function_src(
         subtensor_def_src,
         function_name=function_name,

@@ -103,6 +107,25 @@ def {function_name}({", ".join(input_names)}):
     return numba_njit(func, boundscheck=True)


+@numba_funcify.register(Subtensor)
+def numba_funcify_subtensor_custom(op, node, **kwargs):
+    idxs = get_idx_list(node.inputs, op.idx_list)
+
+    if (
+        idxs
+        and not isinstance(idxs[0], slice)
+        and all(idx == slice(None) for idx in idxs[1:])
+    ):
+
+        @numba_njit
+        def scalar_subtensor_leading_dim(x, idx):
+            return x[idx]
+
+        return scalar_subtensor_leading_dim
+
+    return numba_funcify_default_subtensor(op, node, **kwargs)
+
+
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
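
Note: numba_funcify_subtensor_custom only fast-paths indexing patterns with a single non-slice (e.g. scalar) leading index where all remaining axes are taken whole; everything else is delegated to the default codegen. In NumPy terms (a sketch of the pattern, not code from the commit):

import numpy as np

x = np.arange(15).reshape(5, 3)
x[2]       # scalar leading index, remaining axes untouched -> fast path applies
x[2, 0:2]  # trailing index is not a full slice -> falls back to default codegen
x[1:4]     # leading index is a slice -> falls back to default codegen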

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 10 additions & 0 deletions

@@ -35,6 +35,16 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
         on[...] = ton

     """
+    if nin == 2 and nout == 1:
+
+        @numba_basic.numba_njit
+        def store_core_outputs_2in1out(i0, i1, o0):
+            t0 = core_op_fn(i0, i1)
+            o0[...] = t0
+
+        return store_core_outputs_2in1out
+    print(nin, nout)
+
 inputs = [f"i{i}" for i in range(nin)]
 outputs = [f"o{i}" for i in range(nout)]
 inner_outputs = [f"t{output}" for output in outputs]
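
Note: for nin=2, nout=1 the new branch returns a hand-written wrapper and skips the source-generation path entirely (the stray print(nin, nout) only runs for other arities). What it replaces is the generated source described in the docstring above, which for this arity would read roughly as follows (reconstructed from that docstring):

def store_core_outputs(i0, i1, o0):
    # Generic generated form: call the core op and write its result into the output buffer.
    t0 = core_op_fn(i0, i1)
    o0[...] = t0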
