Skip to content

Commit 848ce19

Browse files
committed
Detect cases where boolean_indexing_set_or_inc does not apply
1 parent 0fced9a commit 848ce19

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

pytensor/tensor/rewriting/jax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ def boolean_indexing_set_or_inc(fgraph, node):
2020
"""
2121

2222
op = node.op
23-
x = node.inputs[0]
24-
y = node.inputs[1]
25-
cond = node.inputs[2]
23+
[x, y, cond] = node.inputs
24+
25+
# This rewrite only works when `y` is a scalar, so it can broadcast to the shape of x[cond]
26+
if y.type.ndim > 0:
27+
return
2628

2729
if not isinstance(cond, TensorVariable):
2830
return

tests/link/jax/test_subtensor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from pytensor.configdefaults import config
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.tensor import subtensor as at_subtensor
8-
from pytensor.tensor.rewriting.jax import boolean_indexing_sum
8+
from pytensor.tensor.rewriting.jax import (
9+
boolean_indexing_set_or_inc,
10+
boolean_indexing_sum,
11+
)
912
from tests.link.jax.test_basic import compare_jax_and_py
1013

1114

@@ -216,7 +219,7 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
216219
217220
This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is
218221
replaced with an equivalent `Switch` `Op`, using the
219-
`jax_boolean_indexing_set_of_inc` rewrite.
222+
`boolean_indexing_set_of_inc` rewrite.
220223
221224
JAX forces users to re-express this logic manually, so this is an
222225
improvement over its user interface.
@@ -237,3 +240,12 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
237240
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
238241
out_fg = FunctionGraph([x_at], [out_at])
239242
compare_jax_and_py(out_fg, [x_np])
243+
244+
245+
def test_boolean_indexing_set_or_inc_not_applicable():
246+
"""Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply."""
247+
x = at.vector("x")
248+
mask = at.as_tensor(x) > 0
249+
out = at_subtensor.set_subtensor(x[mask], [0, 1, 2])
250+
fg = FunctionGraph([x], [out])
251+
assert boolean_indexing_set_or_inc.transform(fg, fg.outputs[0].owner) is None

0 commit comments

Comments
 (0)