Skip to content

Commit ef22377

Browse files
committed
Fix bug when broadcasting branches in local_useless_switch rewrite
1 parent 5a47550 commit ef22377

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

pytensor/tensor/rewriting/basic.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -1023,18 +1023,15 @@ def local_useless_switch(fgraph, node):
10231023

10241024
# if left is right -> left
10251025
if equivalent_up_to_constant_casting(left, right):
1026-
if left.type.broadcastable == out_bcast:
1027-
out_dtype = node.outputs[0].type.dtype
1028-
if left.type.dtype != out_dtype:
1029-
left = cast(left, out_dtype)
1030-
copy_stack_trace(node.outputs + left, left)
1031-
# When not casting, the other inputs of the switch aren't needed in the traceback
1032-
return [left]
1026+
if left.type.broadcastable != out_bcast:
1027+
left, _ = broadcast_arrays(left, cond)
10331028

1034-
else:
1035-
ret = broadcast_arrays(left, cond)[0]
1036-
copy_stack_trace(node.outputs + left, ret)
1037-
return [ret]
1029+
out_dtype = node.outputs[0].type.dtype
1030+
if left.type.dtype != out_dtype:
1031+
left = cast(left, out_dtype)
1032+
1033+
copy_stack_trace(node.outputs + node.inputs, left)
1034+
return [left]
10381035

10391036
# This case happens with scan.
10401037
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)

tests/tensor/rewriting/test_basic.py

+19
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,25 @@ def test_broadcasting_3(self):
10891089
assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc)
10901090
assert not any(node.op == pt.switch for node in f.maker.fgraph.toposort())
10911091

1092+
def test_broadcasting_different_dtype(self):
1093+
cond = vector("x", dtype="bool")
1094+
float32_branch = as_tensor(np.array([0], dtype="float32"))
1095+
float64_branch = as_tensor(np.array([0], dtype="float64"))
1096+
1097+
out = pt.switch(cond, float32_branch, float64_branch)
1098+
expected_out = pt.alloc(float64_branch, cond.shape)
1099+
1100+
rewritten_out = rewrite_graph(
1101+
out, include=("canonicalize", "stabilize", "specialize")
1102+
)
1103+
assert equal_computations([rewritten_out], [expected_out])
1104+
1105+
out = pt.switch(cond, float64_branch, float32_branch)
1106+
rewritten_out = rewrite_graph(
1107+
out, include=("canonicalize", "stabilize", "specialize")
1108+
)
1109+
assert equal_computations([rewritten_out], [expected_out])
1110+
10921111

10931112
class TestLocalMergeSwitchSameCond:
10941113
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)