
Commit 0087e56

rlouf authored and ricardoV94 committed
Add rewrites to re-express boolean indexing logic
1 parent 3747422 commit 0087e56

File tree

3 files changed: +119, -32 lines changed


pytensor/tensor/rewriting/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,9 @@
 import pytensor.tensor.rewriting.basic
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops
+
+# Register JAX specializations
+import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.math
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
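
Importing the module is what activates these rewrites: the `optdb.register(...)` calls in `pytensor/tensor/rewriting/jax.py` run at import time and add both rewrites to the global rewrite database under the "jax" tag. A minimal sketch of checking this, assuming the internal `optdb.__db__` mapping (an implementation detail, not a public API):

    # Hedged sketch: importing the rewrite subpackage triggers the
    # module-level `optdb.register(...)` calls in jax.py (shown below).
    import pytensor.tensor.rewriting  # also imports pytensor.tensor.rewriting.jax
    from pytensor.compile import optdb

    # `__db__` maps rewrite names and tags to registered entries; treating it
    # as a plain mapping here is an assumption about PyTensor internals.
    print("jax_boolean_indexing_set_or_inc" in optdb.__db__)  # expected: True
    print("jax_boolean_indexing_sum" in optdb.__db__)         # expected: True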

pytensor/tensor/rewriting/jax.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from pytensor.compile import optdb
+from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.tensor.var import TensorVariable
+import pytensor.tensor as at
+from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
+from pytensor.tensor.math import Sum
+
+
+@node_rewriter([AdvancedIncSubtensor])
+def boolean_indexing_set_or_inc(fgraph, node):
+    """Replace boolean-indexing `AdvancedIncSubtensor` with an equivalent `Switch`.
+
+    JAX cannot JIT-compile functions that use boolean indexing to set values in
+    an array. A workaround is to re-express this logic using `jax.numpy.where`;
+    this rewrite applies the workaround automatically, improving upon JAX's API.
+
+    """
+    op = node.op
+    x = node.inputs[0]
+    y = node.inputs[1]
+    cond = node.inputs[2]
+
+    if not isinstance(cond, TensorVariable):
+        return
+
+    if cond.type.dtype != "bool":
+        return
+
+    if op.set_instead_of_inc:
+        out = at.where(cond, y, x)
+    else:
+        out = at.where(cond, x + y, x)
+
+    return out.owner.outputs
+
+
+optdb.register(
+    "jax_boolean_indexing_set_or_inc", in2out(boolean_indexing_set_or_inc), "jax", position=100
+)
+
+
+@node_rewriter([Sum])
+def boolean_indexing_sum(fgraph, node):
+    """Replace the sum of an `AdvancedSubtensor` with boolean indexing.
+
+    JAX cannot JIT-compile functions that use boolean indexing, but it can
+    compile the equivalent expressions written with `jax.numpy.where`. This
+    rewrite re-expresses the graph on behalf of the user and thus improves
+    upon JAX's API.
+
+    """
+    operand = node.inputs[0]
+
+    if not isinstance(operand, TensorVariable):
+        return
+
+    if operand.owner is None:
+        return
+
+    if not isinstance(operand.owner.op, AdvancedSubtensor):
+        return
+
+    x = operand.owner.inputs[0]
+    cond = operand.owner.inputs[1]
+
+    if not isinstance(cond, TensorVariable):
+        return
+
+    if cond.type.dtype != "bool":
+        return
+
+    out = at.sum(at.where(cond, x, 0))
+    return out.owner.outputs
+
+
+optdb.register(
+    "jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100
+)
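
The rewrites in this new module exist because of a JAX restriction worth spelling out. A minimal sketch of the behavior (error types may vary across JAX versions; the `jnp.where` forms are exactly the re-expressions the two rewrites produce):

    # Sketch of the JAX limitation the rewrites work around. Boolean-mask
    # indexing yields a data-dependent (dynamic) shape, which `jax.jit`
    # rejects; `jnp.where` expresses the same logic with static shapes.
    import jax
    import jax.numpy as jnp

    x = jnp.arange(-5.0, 5.0)

    # jax.jit(lambda x: x[x < 0].sum())(x)        # fails: non-concrete boolean index
    # jax.jit(lambda x: x.at[x < 0].set(0.0))(x)  # fails for the same reason

    # The re-expressed forms compile and run:
    print(jax.jit(lambda x: jnp.where(x < 0, x, 0).sum())(x))  # sum of negatives: -15.0
    print(jax.jit(lambda x: jnp.where(x < 0, 0.0, x))(x))      # "set" masked values to 0
    print(jax.jit(lambda x: jnp.where(x < 0, x + 1.0, x))(x))  # "increment" masked values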

tests/link/jax/test_subtensor.py

Lines changed: 38 additions & 32 deletions
@@ -80,15 +80,21 @@ def test_jax_Subtensor_boolean_mask():
     compare_jax_and_py(out_fg, [])
 
 
-@pytest.mark.xfail(
-    reason="Re-expressible boolean logic. We need a rewrite PyTensor-side."
-)
 def test_jax_Subtensor_boolean_mask_reexpressible():
-    """Some boolean logic can be re-expressed and JIT-compiled"""
-    x_at = at.arange(-5, 5)
+    """Summing values with boolean indexing.
+
+    This test ensures that the sum of an `AdvancedSubtensor` `Op` with boolean
+    indexing is replaced with the sum of an equivalent `Switch` `Op`, using the
+    `jax_boolean_indexing_sum` rewrite.
+
+    JAX forces users to re-express this logic manually, so this is an
+    improvement over its user interface.
+
+    """
+    x_at = at.vector("x")
     out_at = x_at[x_at < 0].sum()
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
 
 
 def test_jax_IncSubtensor():
@@ -177,42 +183,42 @@ def test_jax_IncSubtensor():
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
-
-@pytest.mark.xfail(
-    reason="Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
-)
-def test_jax_IncSubtensor_boolean_mask_reexpressible():
-    """Some boolean logic can be re-expressed and JIT-compiled"""
-    rng = np.random.default_rng(213234)
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
-
-    mask_at = at.as_tensor(x_np) > 0
-    out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
+    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
+    out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
-    mask_at = at.as_tensor(x_np) > 0
-    out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
+    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
+    out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
 
-def test_jax_IncSubtensors_unsupported():
+def test_jax_IncSubtensor_boolean_indexing_reexpressible():
+    """Setting or incrementing values with boolean indexing.
+
+    This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing
+    are replaced with an equivalent `Switch` `Op`, using the
+    `jax_boolean_indexing_set_or_inc` rewrite.
+
+    JAX forces users to re-express this logic manually, so this is an
+    improvement over its user interface.
+
+    """
     rng = np.random.default_rng(213234)
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
+    x_np = rng.uniform(-1, 1, size=(4, 5)).astype(config.floatX)
 
-    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
+    x_at = at.matrix("x")
+    mask_at = at.as_tensor(x_at) > 0
+    out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [x_np])
 
-    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
+    mask_at = at.as_tensor(x_at) > 0
+    out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [x_np])
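
Beyond the test suite, the practical effect is that boolean indexing now compiles end to end with the JAX backend. A short usage sketch, assuming a JAX-enabled PyTensor installation (`mode="JAX"` selects the JAX linker):

    # Hedged end-to-end sketch: boolean indexing compiles under the JAX
    # backend because the rewrites re-express it with `Switch`/`where`.
    import numpy as np
    import pytensor
    import pytensor.tensor as at

    x = at.vector("x")
    out = x[x < 0].sum()  # rewritten to at.sum(at.where(x < 0, x, 0)) for JAX

    fn = pytensor.function([x], out, mode="JAX")
    print(fn(np.arange(-5.0, 5.0)))  # expected: -15.0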
