Skip to content

Commit 3747422

Browse files
rloufricardoV94
authored andcommitted
Simplify the IncSubtensor dispatcher
1 parent b3f12b2 commit 3747422

File tree

2 files changed

+29
-42
lines changed

2 files changed

+29
-42
lines changed

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import jax
2-
31
from pytensor.link.jax.dispatch.basic import jax_funcify
42
from pytensor.tensor.subtensor import (
53
AdvancedIncSubtensor,
@@ -33,7 +31,7 @@
3331
"""
3432

3533

36-
def assert_indices_jax_compatible(node, idx_list):
34+
def subtensor_assert_indices_jax_compatible(node, idx_list):
3735
from pytensor.graph.basic import Constant
3836
from pytensor.tensor.var import TensorVariable
3937

@@ -55,7 +53,7 @@ def assert_indices_jax_compatible(node, idx_list):
5553
def jax_funcify_Subtensor(op, node, **kwargs):
5654

5755
idx_list = getattr(op, "idx_list", None)
58-
assert_indices_jax_compatible(node, idx_list)
56+
subtensor_assert_indices_jax_compatible(node, idx_list)
5957

6058
def subtensor_constant(x, *ilists):
6159
indices = indices_from_subtensor(ilists, idx_list)
@@ -69,25 +67,19 @@ def subtensor_constant(x, *ilists):
6967

7068
@jax_funcify.register(IncSubtensor)
7169
@jax_funcify.register(AdvancedIncSubtensor1)
72-
def jax_funcify_IncSubtensor(op, **kwargs):
70+
def jax_funcify_IncSubtensor(op, node, **kwargs):
7371

7472
idx_list = getattr(op, "idx_list", None)
7573

7674
if getattr(op, "set_instead_of_inc", False):
77-
jax_fn = getattr(jax.ops, "index_update", None)
78-
79-
if jax_fn is None:
8075

81-
def jax_fn(x, indices, y):
82-
return x.at[indices].set(y)
76+
def jax_fn(x, indices, y):
77+
return x.at[indices].set(y)
8378

8479
else:
85-
jax_fn = getattr(jax.ops, "index_add", None)
86-
87-
if jax_fn is None:
8880

89-
def jax_fn(x, indices, y):
90-
return x.at[indices].add(y)
81+
def jax_fn(x, indices, y):
82+
return x.at[indices].add(y)
9183

9284
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
9385
indices = indices_from_subtensor(ilist, idx_list)
@@ -100,23 +92,17 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
10092

10193

10294
@jax_funcify.register(AdvancedIncSubtensor)
103-
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
95+
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
10496

10597
if getattr(op, "set_instead_of_inc", False):
106-
jax_fn = getattr(jax.ops, "index_update", None)
10798

108-
if jax_fn is None:
109-
110-
def jax_fn(x, indices, y):
111-
return x.at[indices].set(y)
99+
def jax_fn(x, indices, y):
100+
return x.at[indices].set(y)
112101

113102
else:
114-
jax_fn = getattr(jax.ops, "index_add", None)
115-
116-
if jax_fn is None:
117103

118-
def jax_fn(x, indices, y):
119-
return x.at[indices].add(y)
104+
def jax_fn(x, indices, y):
105+
return x.at[indices].add(y)
120106

121107
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
122108
return jax_fn(x, ilist, y)

tests/link/jax/test_subtensor.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import pytest
3-
from jax._src.errors import NonConcreteBooleanIndexError
43

54
import pytensor.tensor as at
65
from pytensor.configdefaults import config
@@ -179,7 +178,11 @@ def test_jax_IncSubtensor():
179178
compare_jax_and_py(out_fg, [])
180179

181180

182-
def test_jax_IncSubtensors_unsupported():
181+
@pytest.mark.xfail(
182+
reason="Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
183+
)
184+
def test_jax_IncSubtensor_boolean_mask_reexpressible():
185+
"""Some boolean logic can be re-expressed and JIT-compiled"""
183186
rng = np.random.default_rng(213234)
184187
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
185188
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
@@ -188,30 +191,28 @@ def test_jax_IncSubtensors_unsupported():
188191
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
189192
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
190193
out_fg = FunctionGraph([], [out_at])
191-
with pytest.raises(
192-
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
193-
):
194-
compare_jax_and_py(out_fg, [])
194+
compare_jax_and_py(out_fg, [])
195195

196-
mask_at = at.as_tensor_variable(x_np) > 0
197-
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
196+
mask_at = at.as_tensor(x_np) > 0
197+
out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
198198
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
199199
out_fg = FunctionGraph([], [out_at])
200-
with pytest.raises(
201-
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
202-
):
203-
compare_jax_and_py(out_fg, [])
200+
compare_jax_and_py(out_fg, [])
201+
202+
203+
def test_jax_IncSubtensors_unsupported():
204+
rng = np.random.default_rng(213234)
205+
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
206+
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
204207

205208
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
206209
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
207210
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
208211
out_fg = FunctionGraph([], [out_at])
209-
with pytest.raises(IndexError, match="Array slice indices must have static"):
210-
compare_jax_and_py(out_fg, [])
212+
compare_jax_and_py(out_fg, [])
211213

212214
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
213215
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
214216
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
215217
out_fg = FunctionGraph([], [out_at])
216-
with pytest.raises(IndexError, match="Array slice indices must have static"):
217-
compare_jax_and_py(out_fg, [])
218+
compare_jax_and_py(out_fg, [])

0 commit comments

Comments
 (0)