Skip to content

Commit 0fced9a

Browse files
committed
Detect cases where boolean_indexing_sum does not apply
1 parent 62f84f5 commit 0fced9a

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

pytensor/tensor/rewriting/jax.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
4848

4949
@node_rewriter([Sum])
5050
def boolean_indexing_sum(fgraph, node):
51-
"""Replace the sum of `AdvancedSubtensor` with boolean indexing.
51+
"""Replace the sum of `AdvancedSubtensor` with exclusively boolean indexing.
5252
5353
JAX cannot JIT-compile functions that use boolean indexing, but can compile
5454
those expressions that can be re-expressed using `jax.numpy.where`. This
@@ -61,21 +61,30 @@ def boolean_indexing_sum(fgraph, node):
6161
if not isinstance(operand, TensorVariable):
6262
return
6363

64+
# If it's not a scalar reduction, it couldn't have been a pure boolean mask
65+
if node.outputs[0].ndim != 0:
66+
return
67+
6468
if operand.owner is None:
6569
return
6670

6771
if not isinstance(operand.owner.op, AdvancedSubtensor):
6872
return
6973

70-
x = operand.owner.inputs[0]
71-
cond = operand.owner.inputs[1]
74+
# Get out if AdvancedSubtensor has more than a single indexing operation
75+
if len(operand.owner.inputs) > 2:
76+
return
77+
78+
[x, cond] = operand.owner.inputs
7279

7380
if not isinstance(cond, TensorVariable):
7481
return
7582

7683
if not cond.type.dtype == "bool":
7784
return
7885

86+
# Output must be a scalar, since pure boolean indexing returns a vector
87+
# No need to worry about axis
7988
out = at.sum(at.where(cond, x, 0))
8089
return out.owner.outputs
8190

tests/link/jax/test_subtensor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.configdefaults import config
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.tensor import subtensor as at_subtensor
8+
from pytensor.tensor.rewriting.jax import boolean_indexing_sum
89
from tests.link.jax.test_basic import compare_jax_and_py
910

1011

@@ -93,10 +94,22 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
9394
improvement over its user interface.
9495
9596
"""
96-
x_at = at.vector("x")
97+
x_at = at.matrix("x")
9798
out_at = x_at[x_at < 0].sum()
9899
out_fg = FunctionGraph([x_at], [out_at])
99-
compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
100+
compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)])
101+
102+
103+
def test_boolean_indexing_sum_not_applicable():
104+
"""Test that boolean_indexing_sum does not return an invalid replacement in cases where it doesn't apply."""
105+
x = at.matrix("x")
106+
out = x[x[:, 0] < 0, :].sum(axis=-1)
107+
fg = FunctionGraph([x], [out])
108+
assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
109+
110+
out = x[x[:, 0] < 0, 0].sum()
111+
fg = FunctionGraph([x], [out])
112+
assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
100113

101114

102115
def test_jax_IncSubtensor():

0 commit comments

Comments
 (0)