Skip to content

Commit 0de0fa9

Browse files
committed
Allow broadcasting in specialized numba dispatch of AdvancedIncSubtensor
1 parent 52bbf59 commit 0de0fa9

File tree

2 files changed

+62
-52
lines changed

2 files changed

+62
-52
lines changed

pytensor/link/numba/dispatch/subtensor.py

+60-50
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130130
if isinstance(idx.type, TensorType)
131131
]
132132

133-
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
134-
# Check that x is not broadcasted to y based on broadcastable info
135-
if len(x_bcast) < len(to_bcast):
136-
return True
137-
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
138-
if x_bcast_dim and not to_bcast_dim:
139-
return True
140-
return False
141-
142133
# Special implementation for consecutive integer vector indices
143134
if (
144135
not basic_idxs
@@ -151,17 +142,6 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
151142
)
152143
# Must be consecutive
153144
and not op.non_consecutive_adv_indexing(node)
154-
# y in set/inc_subtensor cannot be broadcasted
155-
and (
156-
y is None
157-
or not broadcasted_to(
158-
y.type.broadcastable,
159-
(
160-
x.type.broadcastable[: adv_idxs[0]["axis"]]
161-
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
162-
),
163-
)
164-
)
165145
):
166146
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
167147

@@ -191,14 +171,24 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
191171
return numba_funcify_default_subtensor(op, node, **kwargs)
192172

193173

174+
def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
175+
# Check that x is not broadcasted to y based on broadcastable info
176+
if len(x_bcast) < len(to_bcast):
177+
return True
178+
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
179+
if x_bcast_dim and not to_bcast_dim:
180+
return True
181+
return False
182+
183+
194184
def numba_funcify_multiple_integer_vector_indexing(
195185
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
196186
):
197187
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
198188
if isinstance(op, AdvancedSubtensor):
199-
y, idxs = None, node.inputs[1:]
189+
idxs = node.inputs[1:]
200190
else:
201-
y, *idxs = node.inputs[1:]
191+
idxs = node.inputs[2:]
202192

203193
first_axis = next(
204194
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
@@ -211,6 +201,10 @@ def numba_funcify_multiple_integer_vector_indexing(
211201
)
212202
except StopIteration:
213203
after_last_axis = len(idxs)
204+
last_axis = after_last_axis - 1
205+
206+
vector_indices = idxs[first_axis:after_last_axis]
207+
assert all(v.type.broadcastable == (False,) for v in vector_indices)
214208

215209
if isinstance(op, AdvancedSubtensor):
216210

@@ -231,43 +225,59 @@ def advanced_subtensor_multiple_vector(x, *idxs):
231225

232226
return advanced_subtensor_multiple_vector
233227

234-
elif op.set_instead_of_inc:
228+
else:
235229
inplace = op.inplace
236230

237-
@numba_njit
238-
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
239-
vec_idxs = idxs[first_axis:after_last_axis]
240-
x_shape = x.shape
231+
# Check if y must be broadcasted
232+
# Includes the last integer vector index,
233+
x, y = node.inputs[:2]
234+
indexed_bcast_dims = (
235+
*x.type.broadcastable[:first_axis],
236+
*x.type.broadcastable[last_axis:],
237+
)
238+
y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims)
241239

242-
if inplace:
243-
out = x
244-
else:
245-
out = x.copy()
240+
if op.set_instead_of_inc:
246241

247-
for outer in np.ndindex(x_shape[:first_axis]):
248-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
249-
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
250-
return out
242+
@numba_njit
243+
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
244+
vec_idxs = idxs[first_axis:after_last_axis]
245+
x_shape = x.shape
251246

252-
return advanced_set_subtensor_multiple_vector
247+
if inplace:
248+
out = x
249+
else:
250+
out = x.copy()
253251

254-
else:
255-
inplace = op.inplace
252+
if y_is_broadcasted:
253+
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
256254

257-
@numba_njit
258-
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
259-
vec_idxs = idxs[first_axis:after_last_axis]
260-
x_shape = x.shape
255+
for outer in np.ndindex(x_shape[:first_axis]):
256+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
257+
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
258+
return out
259+
260+
return advanced_set_subtensor_multiple_vector
261+
262+
else:
263+
264+
@numba_njit
265+
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
266+
vec_idxs = idxs[first_axis:after_last_axis]
267+
x_shape = x.shape
268+
269+
if inplace:
270+
out = x
271+
else:
272+
out = x.copy()
261273

262-
if inplace:
263-
out = x
264-
else:
265-
out = x.copy()
274+
if y_is_broadcasted:
275+
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
266276

267-
for outer in np.ndindex(x_shape[:first_axis]):
268-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
269-
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
270-
return out
277+
for outer in np.ndindex(x_shape[:first_axis]):
278+
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
279+
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
280+
return out
271281

272282
return advanced_inc_subtensor_multiple_vector
273283

tests/link/numba/test_subtensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ def test_AdvancedIncSubtensor1(x, y, indices):
392392
np.array(-99), # Broadcasted value
393393
([1, 2], [2, 3]), # 2 vector indices
394394
False,
395-
True,
396-
True,
395+
False,
396+
False,
397397
),
398398
(
399399
np.arange(3 * 4 * 5).reshape((3, 4, 5)),

0 commit comments

Comments
 (0)