Skip to content

Commit 4f7d709

Browse files
committed
Simplify local_[mul|div]_switch_sink and fix downcasting bug
1 parent 4d0aa3f commit 4f7d709

File tree

2 files changed

+93
-115
lines changed

2 files changed

+93
-115
lines changed

pytensor/tensor/rewriting/math.py

+69-114
Original file line numberDiff line numberDiff line change
@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
621621
part of the graph.
622622
623623
"""
624-
for idx, i in enumerate(node.inputs):
625-
if i.owner and i.owner.op == switch:
626-
switch_node = i.owner
627-
try:
628-
if (
629-
get_underlying_scalar_constant_value(
630-
switch_node.inputs[1], only_process_constants=True
631-
)
632-
== 0.0
633-
):
634-
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
635-
fmul = mul(*([*listmul, switch_node.inputs[2]]))
636-
637-
# Copy over stacktrace for elementwise multiplication op
638-
# from previous elementwise multiplication op.
639-
# An error in the multiplication (e.g. errors due to
640-
# inconsistent shapes), will point to the
641-
# multiplication op.
642-
copy_stack_trace(node.outputs, fmul)
643-
644-
fct = [switch(switch_node.inputs[0], 0, fmul)]
645-
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
646-
647-
# Copy over stacktrace for switch op from both previous
648-
# elementwise multiplication op and previous switch op,
649-
# because an error in this part can be caused by either
650-
# of the two previous ops.
651-
copy_stack_trace(node.outputs + switch_node.outputs, fct)
652-
return fct
653-
except NotScalarConstantError:
654-
pass
655-
try:
656-
if (
657-
get_underlying_scalar_constant_value(
658-
switch_node.inputs[2], only_process_constants=True
659-
)
660-
== 0.0
661-
):
662-
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
663-
fmul = mul(*([*listmul, switch_node.inputs[1]]))
664-
# Copy over stacktrace for elementwise multiplication op
665-
# from previous elementwise multiplication op.
666-
# An error in the multiplication (e.g. errors due to
667-
# inconsistent shapes), will point to the
668-
# multiplication op.
669-
copy_stack_trace(node.outputs, fmul)
670-
671-
fct = [switch(switch_node.inputs[0], fmul, 0)]
672-
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
673-
674-
# Copy over stacktrace for switch op from both previous
675-
# elementwise multiplication op and previous switch op,
676-
# because an error in this part can be caused by either
677-
# of the two previous ops.
678-
copy_stack_trace(node.outputs + switch_node.outputs, fct)
679-
return fct
680-
except NotScalarConstantError:
681-
pass
682-
return False
624+
for mul_inp_idx, mul_inp in enumerate(node.inputs):
625+
if mul_inp.owner and mul_inp.owner.op == switch:
626+
switch_node = mul_inp.owner
627+
# Look for a zero as the first or second branch of the switch
628+
for branch in range(2):
629+
zero_switch_input = switch_node.inputs[1 + branch]
630+
if not get_unique_constant_value(zero_switch_input) == 0.0:
631+
continue
632+
633+
switch_cond = switch_node.inputs[0]
634+
other_switch_input = switch_node.inputs[1 + (1 - branch)]
635+
636+
listmul = list(node.inputs)
637+
listmul[mul_inp_idx] = other_switch_input
638+
fmul = mul(*listmul)
639+
640+
# Copy over stacktrace for elementwise multiplication op
641+
# from previous elementwise multiplication op.
642+
# An error in the multiplication (e.g. errors due to
643+
# inconsistent shapes), will point to the
644+
# multiplication op.
645+
copy_stack_trace(node.outputs, fmul)
646+
647+
if branch == 0:
648+
fct = switch(switch_cond, zero_switch_input, fmul)
649+
else:
650+
fct = switch(switch_cond, fmul, zero_switch_input)
651+
652+
# Tell debug_mode than the output is correct, even if nan disappear
653+
fct.tag.values_eq_approx = values_eq_approx_remove_nan
654+
655+
# Copy over stacktrace for switch op from both previous
656+
# elementwise multiplication op and previous switch op,
657+
# because an error in this part can be caused by either
658+
# of the two previous ops.
659+
copy_stack_trace(node.outputs + switch_node.outputs, fct)
660+
return [fct]
683661

684662

685663
@register_canonicalize
@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
699677
See `local_mul_switch_sink` for more details.
700678
701679
"""
702-
op = node.op
703-
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
704-
switch_node = node.inputs[0].owner
705-
try:
706-
if (
707-
get_underlying_scalar_constant_value(
708-
switch_node.inputs[1], only_process_constants=True
709-
)
710-
== 0.0
711-
):
712-
fdiv = op(switch_node.inputs[2], node.inputs[1])
713-
# Copy over stacktrace for elementwise division op
714-
# from previous elementwise multiplication op.
715-
# An error in the division (e.g. errors due to
716-
# inconsistent shapes or division by zero),
717-
# will point to the new division op.
718-
copy_stack_trace(node.outputs, fdiv)
719-
720-
fct = [switch(switch_node.inputs[0], 0, fdiv)]
721-
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
722-
723-
# Copy over stacktrace for switch op from both previous
724-
# elementwise division op and previous switch op,
725-
# because an error in this part can be caused by either
726-
# of the two previous ops.
727-
copy_stack_trace(node.outputs + switch_node.outputs, fct)
728-
return fct
729-
except NotScalarConstantError:
730-
pass
731-
try:
732-
if (
733-
get_underlying_scalar_constant_value(
734-
switch_node.inputs[2], only_process_constants=True
735-
)
736-
== 0.0
737-
):
738-
fdiv = op(switch_node.inputs[1], node.inputs[1])
739-
# Copy over stacktrace for elementwise division op
740-
# from previous elementwise multiplication op.
741-
# An error in the division (e.g. errors due to
742-
# inconsistent shapes or division by zero),
743-
# will point to the new division op.
744-
copy_stack_trace(node.outputs, fdiv)
745-
746-
fct = [switch(switch_node.inputs[0], fdiv, 0)]
747-
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
680+
num, denom = node.inputs
748681

749-
# Copy over stacktrace for switch op from both previous
750-
# elementwise division op and previous switch op,
751-
# because an error in this part can be caused by either
752-
# of the two previous ops.
753-
copy_stack_trace(node.outputs + switch_node.outputs, fct)
754-
return fct
755-
except NotScalarConstantError:
756-
pass
757-
return False
682+
if num.owner and num.owner.op == switch:
683+
switch_node = num.owner
684+
# Look for a zero as the first or second branch of the switch
685+
for branch in range(2):
686+
zero_switch_input = switch_node.inputs[1 + branch]
687+
if not get_unique_constant_value(zero_switch_input) == 0.0:
688+
continue
689+
690+
switch_cond = switch_node.inputs[0]
691+
other_switch_input = switch_node.inputs[1 + (1 - branch)]
692+
693+
fdiv = node.op(other_switch_input, denom)
694+
695+
# Copy over stacktrace for elementwise division op
696+
# from previous elementwise multiplication op.
697+
# An error in the division (e.g. errors due to
698+
# inconsistent shapes or division by zero),
699+
# will point to the new division op.
700+
copy_stack_trace(node.outputs, fdiv)
701+
702+
fct = switch(switch_cond, zero_switch_input, fdiv)
703+
704+
# Tell debug_mode than the output is correct, even if nan disappear
705+
fct.tag.values_eq_approx = values_eq_approx_remove_nan
706+
707+
# Copy over stacktrace for switch op from both previous
708+
# elementwise division op and previous switch op,
709+
# because an error in this part can be caused by either
710+
# of the two previous ops.
711+
copy_stack_trace(node.outputs + switch_node.outputs, fct)
712+
return [fct]
758713

759714

760715
class AlgebraicCanonizer(NodeRewriter):

tests/tensor/rewriting/test_math.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,11 @@
9797
from pytensor.tensor.rewriting.math import (
9898
compute_mul,
9999
is_1pexp,
100+
local_div_switch_sink,
100101
local_grad_log_erfc_neg,
101102
local_greedy_distributor,
102103
local_mul_canonizer,
104+
local_mul_switch_sink,
103105
local_reduce_chain,
104106
local_sum_prod_of_mul_or_div,
105107
mul_canonizer,
@@ -2115,7 +2117,6 @@ def test_local_mul_switch_sink(self):
21152117
f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
21162118
assert f(5) == 1, f(5)
21172119

2118-
@pytest.mark.slow
21192120
def test_local_div_switch_sink(self):
21202121
c = dscalar()
21212122
idx = 0
@@ -2149,6 +2150,28 @@ def test_local_div_switch_sink(self):
21492150
].size
21502151
idx += 1
21512152

2153+
@pytest.mark.parametrize(
2154+
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
2155+
)
2156+
def test_local_mul_div_switch_sink_cast(self, op, rewrite):
2157+
"""Check that we don't downcast during the rewrite.
2158+
2159+
Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
2160+
"""
2161+
cond = scalar("cond", dtype="bool")
2162+
# The zero branch upcasts the output, so we can't ignore its dtype
2163+
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
2164+
other_branch = scalar("other_branch", dtype="float32")
2165+
outer_var = scalar("mul_var", dtype="bool")
2166+
2167+
out = op(switch(cond, zero_branch, other_branch), outer_var)
2168+
fgraph = FunctionGraph(outputs=[out], clone=False)
2169+
[new_out] = rewrite.transform(fgraph, out.owner)
2170+
assert new_out.type.dtype == out.type.dtype
2171+
2172+
expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
2173+
assert equal_computations([new_out], [expected_out])
2174+
21522175

21532176
@pytest.mark.skipif(
21542177
config.cxx == "",

0 commit comments

Comments
 (0)