5
5
from pytensor .configdefaults import config
6
6
from pytensor .graph .fg import FunctionGraph
7
7
from pytensor .tensor import subtensor as at_subtensor
8
- from pytensor .tensor .rewriting .jax import boolean_indexing_sum
8
+ from pytensor .tensor .rewriting .jax import (
9
+ boolean_indexing_set_or_inc ,
10
+ boolean_indexing_sum ,
11
+ )
9
12
from tests .link .jax .test_basic import compare_jax_and_py
10
13
11
14
@@ -216,7 +219,7 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
216
219
217
220
This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is
218
221
replaced with an equivalent `Switch` `Op`, using the
219
- `jax_boolean_indexing_set_of_inc ` rewrite.
222
+ `boolean_indexing_set_of_inc ` rewrite.
220
223
221
224
JAX forces users to re-express this logic manually, so this is an
222
225
improvement over its user interface.
@@ -237,3 +240,12 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
237
240
assert isinstance (out_at .owner .op , at_subtensor .AdvancedIncSubtensor )
238
241
out_fg = FunctionGraph ([x_at ], [out_at ])
239
242
compare_jax_and_py (out_fg , [x_np ])
243
+
244
+
245
+ def test_boolean_indexing_set_or_inc_not_applicable ():
246
+ """Test that `boolean_indexing_set_or_inc` does not return an invalid replacement in cases where it doesn't apply."""
247
+ x = at .vector ("x" )
248
+ mask = at .as_tensor (x ) > 0
249
+ out = at_subtensor .set_subtensor (x [mask ], [0 , 1 , 2 ])
250
+ fg = FunctionGraph ([x ], [out ])
251
+ assert boolean_indexing_set_or_inc .transform (fg , fg .outputs [0 ].owner ) is None
0 commit comments