Skip to content

Commit 60bc368

Browse files
committed
Use Dimshuffle for expand_dims
1 parent 652d0b6 commit 60bc368

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/tensor/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4013,10 +4013,10 @@ def expand_dims(
40134013
out_ndim = len(axis) + a.ndim
40144014
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
40154015

4016-
shape_it = iter(a.shape)
4017-
shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
4016+
dim_it = iter(range(a.ndim))
4017+
pattern = ["x" if ax in axis else next(dim_it) for ax in range(out_ndim)]
40184018

4019-
return a.reshape(shape)
4019+
return a.dimshuffle(pattern)
40204020

40214021

40224022
def _make_along_axis_idx(arr_shape, indices, axis):

0 commit comments

Comments
 (0)