Skip to content

Commit 33948f6

Browse files
committed
Simplify Python DimShuffle implementation
1 parent 2652d0f commit 33948f6

File tree

1 file changed

+10
-23
lines changed

1 file changed

+10
-23
lines changed

pytensor/tensor/elemwise.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,6 @@ def params_type(self):
125125
input_ndim=int64,
126126
)
127127

128-
@property
129-
def _new_order(self):
130-
# Param for C code.
131-
# self.new_order may contain 'x', which is not a valid integer value.
132-
# We replace it with -1.
133-
return [(-1 if x == "x" else x) for x in self.new_order]
134-
135128
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
136129
super().__init__([self.c_func_file], self.c_func_name)
137130

@@ -140,6 +133,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
140133

141134
self.input_ndim = input_ndim
142135
self.new_order = tuple(new_order)
136+
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
143137

144138
for i, j in enumerate(new_order):
145139
if j != "x":
@@ -231,22 +225,15 @@ def __str__(self):
231225
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
232226

233227
def perform(self, node, inp, out):
234-
(res,) = inp
235-
(storage,) = out
236-
237-
if not isinstance(res, np.ndarray | np.memmap):
238-
raise TypeError(res)
239-
240-
# Put dropped axis at end
241-
res = res.transpose(self.transposition)
242-
243-
# Define new shape without dropped axis and including new ones
244-
new_shape = list(res.shape[: len(self.shuffle)])
245-
for augm in self.augment:
246-
new_shape.insert(augm, 1)
247-
res = res.reshape(new_shape)
248-
249-
storage[0] = np.asarray(res)
228+
(inp,) = inp
229+
new_order = self._new_order
230+
old_shape = inp.shape
231+
old_strides = inp.strides
232+
233+
res = inp.view()
234+
res.shape = [1 if i == -1 else old_shape[i] for i in new_order]
235+
res.strides = [0 if i == -1 else old_strides[i] for i in new_order]
236+
out[0][0] = res
250237

251238
def infer_shape(self, fgraph, node, shapes):
252239
(ishp,) = shapes

0 commit comments

Comments
 (0)