Skip to content

Commit e73258b

Browse files
committed
Fix bug in local_reduce_join rewrite.
The helper `apply_local_dimshuffle_lift` requires a FunctionGraph when elemwise inputs are involved.
1 parent e934ac7 commit e73258b

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

pytensor/tensor/rewriting/math.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node):
16201620
if not inp.type.broadcastable[join_axis]:
16211621
return None
16221622
# Most times inputs to join have an expand_dims, we eagerly clean up those here
1623-
new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
1623+
new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
16241624
new_inputs.append(new_input)
16251625

16261626
ret = Elemwise(node.op.scalar_op)(*new_inputs)

tests/tensor/rewriting/test_math.py

+19
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
local_mul_canonizer,
104104
local_mul_switch_sink,
105105
local_reduce_chain,
106+
local_reduce_join,
106107
local_sum_prod_of_mul_or_div,
107108
mul_canonizer,
108109
parse_mul_tree,
@@ -3415,6 +3416,24 @@ def test_not_supported_unequal_shapes(self):
34153416
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
34163417
)
34173418

3419+
def test_non_ds_inputs(self):
3420+
"""Make sure rewrite works when inputs to join are not the usual DimShuffle.
3421+
3422+
Sum{axis=1} [id A] <Vector(float64, shape=(3,))>
3423+
└─ Join [id B] <Matrix(float64, shape=(3, 3))>
3424+
├─ 1 [id C] <Scalar(int8, shape=())>
3425+
├─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(3, 1))>
3426+
├─ Sub [id E] <Matrix(float64, shape=(3, 1))>
3427+
└─ Sub [id F] <Matrix(float64, shape=(3, 1))>
3428+
"""
3429+
x = vector("x")
3430+
out = join(0, exp(x[None]), log(x[None])).sum(axis=0)
3431+
3432+
fg = FunctionGraph([x], [out], clone=False)
3433+
[rewritten_out] = local_reduce_join.transform(fg, out.owner)
3434+
expected_out = add(exp(x), log(x))
3435+
assert equal_computations([rewritten_out], [expected_out])
3436+
34183437

34193438
def test_local_useless_adds():
34203439
default_mode = get_default_mode()

0 commit comments

Comments
 (0)