@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130
130
if isinstance (idx .type , TensorType )
131
131
]
132
132
133
- def broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
134
- # Check that x is not broadcasted to y based on broadcastable info
135
- if len (x_bcast ) < len (to_bcast ):
136
- return True
137
- for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
138
- if x_bcast_dim and not to_bcast_dim :
139
- return True
140
- return False
141
-
142
133
# Special implementation for consecutive integer vector indices
143
134
if (
144
135
not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
151
142
)
152
143
# Must be consecutive
153
144
and not op .non_consecutive_adv_indexing (node )
154
- # y in set/inc_subtensor cannot be broadcasted
155
- and (
156
- y is None
157
- or not broadcasted_to (
158
- y .type .broadcastable ,
159
- (
160
- x .type .broadcastable [: adv_idxs [0 ]["axis" ]]
161
- + x .type .broadcastable [adv_idxs [- 1 ]["axis" ] :]
162
- ),
163
- )
164
- )
165
145
):
166
146
return numba_funcify_multiple_integer_vector_indexing (op , node , ** kwargs )
167
147
@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
191
171
return numba_funcify_default_subtensor (op , node , ** kwargs )
192
172
193
173
174
+ def _broadcasted_to (x_bcast : tuple [bool , ...], to_bcast : tuple [bool , ...]):
175
+ # Check that x is not broadcasted to y based on broadcastable info
176
+ if len (x_bcast ) < len (to_bcast ):
177
+ return True
178
+ for x_bcast_dim , to_bcast_dim in zip (x_bcast , to_bcast , strict = True ):
179
+ if x_bcast_dim and not to_bcast_dim :
180
+ return True
181
+ return False
182
+
183
+
194
184
def numba_funcify_multiple_integer_vector_indexing (
195
185
op : AdvancedSubtensor | AdvancedIncSubtensor , node , ** kwargs
196
186
):
197
187
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
198
188
if isinstance (op , AdvancedSubtensor ):
199
- y , idxs = None , node .inputs [1 :]
189
+ idxs = node .inputs [1 :]
200
190
else :
201
- y , * idxs = node .inputs [1 :]
191
+ idxs = node .inputs [2 :]
202
192
203
193
first_axis = next (
204
194
i for i , idx in enumerate (idxs ) if isinstance (idx .type , TensorType )
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
211
201
)
212
202
except StopIteration :
213
203
after_last_axis = len (idxs )
204
+ last_axis = after_last_axis - 1
205
+
206
+ vector_indices = idxs [first_axis :after_last_axis ]
207
+ assert all (v .type .broadcastable == (False ,) for v in vector_indices )
214
208
215
209
if isinstance (op , AdvancedSubtensor ):
216
210
@@ -231,43 +225,59 @@ def advanced_subtensor_multiple_vector(x, *idxs):
231
225
232
226
return advanced_subtensor_multiple_vector
233
227
234
- elif op . set_instead_of_inc :
228
+ else :
235
229
inplace = op .inplace
236
230
237
- @numba_njit
238
- def advanced_set_subtensor_multiple_vector (x , y , * idxs ):
239
- vec_idxs = idxs [first_axis :after_last_axis ]
240
- x_shape = x .shape
231
+ # Check if y must be broadcasted
232
+ # Includes the last integer vector index,
233
+ x , y = node .inputs [:2 ]
234
+ indexed_bcast_dims = (
235
+ * x .type .broadcastable [:first_axis ],
236
+ * x .type .broadcastable [last_axis :],
237
+ )
238
+ y_is_broadcasted = _broadcasted_to (y .type .broadcastable , indexed_bcast_dims )
241
239
242
- if inplace :
243
- out = x
244
- else :
245
- out = x .copy ()
240
+ if op .set_instead_of_inc :
246
241
247
- for outer in np . ndindex ( x_shape [: first_axis ]):
248
- for i , scalar_idxs in enumerate ( zip ( * vec_idxs )): # noqa: B905
249
- out [( * outer , * scalar_idxs )] = y [( * outer , i ) ]
250
- return out
242
+ @ numba_njit
243
+ def advanced_set_subtensor_multiple_vector ( x , y , * idxs ):
244
+ vec_idxs = idxs [ first_axis : after_last_axis ]
245
+ x_shape = x . shape
251
246
252
- return advanced_set_subtensor_multiple_vector
247
+ if inplace :
248
+ out = x
249
+ else :
250
+ out = x .copy ()
253
251
254
- else :
255
- inplace = op . inplace
252
+ if y_is_broadcasted :
253
+ y = np . broadcast_to ( y , x_shape [: first_axis ] + x_shape [ last_axis :])
256
254
257
- @numba_njit
258
- def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
259
- vec_idxs = idxs [first_axis :after_last_axis ]
260
- x_shape = x .shape
255
+ for outer in np .ndindex (x_shape [:first_axis ]):
256
+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
257
+ out [(* outer , * scalar_idxs )] = y [(* outer , i )]
258
+ return out
259
+
260
+ return advanced_set_subtensor_multiple_vector
261
+
262
+ else :
263
+
264
+ @numba_njit
265
+ def advanced_inc_subtensor_multiple_vector (x , y , * idxs ):
266
+ vec_idxs = idxs [first_axis :after_last_axis ]
267
+ x_shape = x .shape
268
+
269
+ if inplace :
270
+ out = x
271
+ else :
272
+ out = x .copy ()
261
273
262
- if inplace :
263
- out = x
264
- else :
265
- out = x .copy ()
274
+ if y_is_broadcasted :
275
+ y = np .broadcast_to (y , x_shape [:first_axis ] + x_shape [last_axis :])
266
276
267
- for outer in np .ndindex (x_shape [:first_axis ]):
268
- for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
269
- out [(* outer , * scalar_idxs )] += y [(* outer , i )]
270
- return out
277
+ for outer in np .ndindex (x_shape [:first_axis ]):
278
+ for i , scalar_idxs in enumerate (zip (* vec_idxs )): # noqa: B905
279
+ out [(* outer , * scalar_idxs )] += y [(* outer , i )]
280
+ return out
271
281
272
282
return advanced_inc_subtensor_multiple_vector
273
283
0 commit comments