Skip to content

Commit 2c4a3e7

Browse files
committed
Tag rewrites that make shape assumptions
1 parent 5db0d83 commit 2c4a3e7

File tree

4 files changed

+27
-31
lines changed

4 files changed

+27
-31
lines changed

pytensor/configdefaults.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -682,25 +682,7 @@ def add_traceback_configvars():
682682

683683

684684
def add_experimental_configvars():
685-
config.add(
686-
"experimental__local_alloc_elemwise",
687-
"DEPRECATED: If True, enable the experimental"
688-
" optimization local_alloc_elemwise."
689-
" Generates error if not True. Use"
690-
" optimizer_excluding=local_alloc_elemwise"
691-
" to disable.",
692-
BoolParam(True),
693-
in_c_key=False,
694-
)
695-
696-
# False could make the graph faster but not as safe.
697-
config.add(
698-
"experimental__local_alloc_elemwise_assert",
699-
"When the local_alloc_elemwise is applied, add"
700-
" an assert to highlight shape errors.",
701-
BoolParam(True),
702-
in_c_key=False,
703-
)
685+
return
704686

705687

706688
def add_error_and_warning_configvars():

pytensor/tensor/rewriting/basic.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def local_scalar_tensor_scalar(fgraph, node):
256256
return [s]
257257

258258

259-
@register_specialize("local_alloc_elemwise")
259+
@register_specialize("shape_unsafe")
260260
@node_rewriter([Elemwise])
261261
def local_elemwise_alloc(fgraph, node):
262262
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
@@ -377,7 +377,7 @@ def dimshuffled_alloc(i):
377377
return ret
378378

379379

380-
@register_canonicalize
380+
@register_canonicalize("shape_unsafe")
381381
@node_rewriter([Elemwise])
382382
def local_fill_sink(fgraph, node):
383383
"""
@@ -428,8 +428,8 @@ def local_fill_sink(fgraph, node):
428428
return replacements
429429

430430

431-
@register_specialize
432-
@register_stabilize
431+
@register_specialize("shape_unsafe")
432+
@register_stabilize("shape_unsafe")
433433
@node_rewriter([fill])
434434
def local_fill_to_alloc(fgraph, node):
435435
r"""Remove `fill`\s or replace them with `Alloc`\s.
@@ -479,8 +479,8 @@ def local_fill_to_alloc(fgraph, node):
479479
)
480480

481481

482-
@register_canonicalize("fast_compile")
483-
@register_useless
482+
@register_canonicalize("fast_compile", "shape_unsafe")
483+
@register_useless("shape_unsafe")
484484
@node_rewriter([fill])
485485
def local_useless_fill(fgraph, node):
486486
"""fill(s,v) -> v
@@ -500,10 +500,10 @@ def local_useless_fill(fgraph, node):
500500
return [v]
501501

502502

503-
@register_specialize
504-
@register_stabilize
505-
@register_canonicalize
506-
@register_useless
503+
@register_specialize("shape_unsafe")
504+
@register_stabilize("shape_unsafe")
505+
@register_canonicalize("shape_unsafe")
506+
@register_useless("shape_unsafe")
507507
@node_rewriter([Alloc])
508508
def local_useless_alloc(fgraph, node):
509509
"""

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
11761176
local_mul_canonizer = AlgebraicCanonizer(
11771177
mul, true_div, reciprocal, mul_calculate, False
11781178
)
1179-
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
1179+
register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canonizer")
11801180

11811181

11821182
@register_canonicalize
@@ -2493,7 +2493,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
24932493
)
24942494

24952495

2496-
register_canonicalize(local_add_canonizer, name="local_add_canonizer")
2496+
register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
24972497

24982498

24992499
def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0):

tests/tensor/rewriting/test_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,3 +1933,17 @@ def test_misc(self):
19331933
x_val = np.random.random((1, 5)).astype(self.dtype)
19341934
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
19351935
assert np.array_equal(func(y_val, x_val), exp_res)
1936+
1937+
1938+
def test_shape_unsafe_tag():
1939+
mode = get_mode("FAST_RUN")
1940+
x = vector("x")
1941+
y = vector("y")
1942+
out = x * y / y
1943+
1944+
fn = function([x, y], out, mode=mode)
1945+
np.testing.assert_equal(fn([0, 1], [2, 3, 4]), [0, 1])
1946+
1947+
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
1948+
with pytest.raises(ValueError):
1949+
fn([0, 1], [2, 3, 4]), [0, 1]

0 commit comments

Comments
 (0)