@@ -80,15 +80,21 @@ def test_jax_Subtensor_boolean_mask():
     compare_jax_and_py(out_fg, [])
 
 
-@pytest.mark.xfail(
-    reason="Re-expressible boolean logic. We need a rewrite PyTensor-side."
-)
 def test_jax_Subtensor_boolean_mask_reexpressible():
-    """Some boolean logic can be re-expressed and JIT-compiled"""
-    x_at = at.arange(-5, 5)
+    """Summing values with boolean indexing.
+
+    This test ensures that the sum of an `AdvancedSubtensor` `Op` with boolean
+    indexing is replaced with the sum of an equivalent `Switch` `Op`, using the
+    `jax_boolean_indexing_sum` rewrite.
+
+    JAX forces users to re-express this logic manually, so this is an
+    improvement over its user interface.
+
+    """
+    x_at = at.vector("x")
     out_at = x_at[x_at < 0].sum()
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
 
 
 def test_jax_IncSubtensor():
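
For reference, the manual re-expression that JAX requires for a boolean-masked sum, and which the `jax_boolean_indexing_sum` rewrite is meant to produce automatically, looks roughly like the sketch below (illustrative only, not part of the diff):

import jax
import jax.numpy as jnp

x = jnp.arange(-5, 5, dtype=jnp.float32)

# x[x < 0].sum() cannot be jit-compiled directly: the boolean mask would give
# the intermediate array a data-dependent shape.  Replacing the mask with a
# where/switch keeps the shape static and is numerically equivalent.
neg_sum = jax.jit(lambda v: jnp.where(v < 0, v, 0.0).sum())(x)  # -15.0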
@@ -177,42 +183,42 @@ def test_jax_IncSubtensor():
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
-
-@pytest.mark.xfail(
-    reason="Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
-)
-def test_jax_IncSubtensor_boolean_mask_reexpressible():
-    """Some boolean logic can be re-expressed and JIT-compiled"""
-    rng = np.random.default_rng(213234)
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
-
-    mask_at = at.as_tensor(x_np) > 0
-    out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
+    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
+    out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
-    mask_at = at.as_tensor(x_np) > 0
-    out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
+    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
+    out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([], [out_at])
     compare_jax_and_py(out_fg, [])
 
 
-def test_jax_IncSubtensors_unsupported():
+def test_jax_IncSubtensor_boolean_indexing_reexpressible():
+    """Setting or incrementing values with boolean indexing.
+
+    This test ensures that an `AdvancedIncSubtensor` `Op` with boolean indexing
+    is replaced with an equivalent `Switch` `Op`, using the
+    `jax_boolean_indexing_set_or_inc` rewrite.
+
+    JAX forces users to re-express this logic manually, so this is an
+    improvement over its user interface.
+
+    """
     rng = np.random.default_rng(213234)
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
+    x_np = rng.uniform(-1, 1, size=(4, 5)).astype(config.floatX)
 
-    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
+    x_at = at.matrix("x")
+    mask_at = at.as_tensor(x_at) > 0
+    out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [x_np])
 
-    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
+    mask_at = at.as_tensor(x_at) > 0
+    out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([], [out_at])
-    compare_jax_and_py(out_fg, [])
+    out_fg = FunctionGraph([x_at], [out_at])
+    compare_jax_and_py(out_fg, [x_np])
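
For reference, a rough sketch of the switch-style formulation that the boolean-indexing set/inc rewrite targets, written directly in JAX (names here are illustrative, not taken from the test suite):

import jax
import jax.numpy as jnp

x = jnp.array([[-1.0, 2.0], [3.0, -4.0]])
mask = x > 0

# set_subtensor(x[mask], 0.0): write 0.0 where the mask holds, keep x elsewhere
set_result = jax.jit(lambda v, m: jnp.where(m, 0.0, v))(x, mask)

# inc_subtensor(x[mask], 1.0): add 1.0 only where the mask holds
inc_result = jax.jit(lambda v, m: jnp.where(m, v + 1.0, v))(x, mask)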