@@ -125,13 +125,6 @@ def params_type(self):
125
125
input_ndim = int64 ,
126
126
)
127
127
128
- @property
129
- def _new_order (self ):
130
- # Param for C code.
131
- # self.new_order may contain 'x', which is not a valid integer value.
132
- # We replace it with -1.
133
- return [(- 1 if x == "x" else x ) for x in self .new_order ]
134
-
135
128
def __init__ (self , * , input_ndim : int , new_order : Sequence [int | Literal ["x" ]]):
136
129
super ().__init__ ([self .c_func_file ], self .c_func_name )
137
130
@@ -140,6 +133,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
140
133
141
134
self .input_ndim = input_ndim
142
135
self .new_order = tuple (new_order )
136
+ self ._new_order = [(- 1 if x == "x" else x ) for x in self .new_order ]
143
137
144
138
for i , j in enumerate (new_order ):
145
139
if j != "x" :
@@ -231,22 +225,15 @@ def __str__(self):
231
225
return f"DimShuffle{{order=[{ ',' .join (map (str , self .new_order ))} ]}}"
232
226
233
227
def perform (self , node , inp , out ):
234
- (res ,) = inp
235
- (storage ,) = out
236
-
237
- if not isinstance (res , np .ndarray | np .memmap ):
238
- raise TypeError (res )
239
-
240
- # Put dropped axis at end
241
- res = res .transpose (self .transposition )
242
-
243
- # Define new shape without dropped axis and including new ones
244
- new_shape = list (res .shape [: len (self .shuffle )])
245
- for augm in self .augment :
246
- new_shape .insert (augm , 1 )
247
- res = res .reshape (new_shape )
248
-
249
- storage [0 ] = np .asarray (res )
228
+ (inp ,) = inp
229
+ new_order = self ._new_order
230
+ old_shape = inp .shape
231
+ old_strides = inp .strides
232
+
233
+ res = inp .view ()
234
+ res .shape = [1 if i == - 1 else old_shape [i ] for i in new_order ]
235
+ res .strides = [0 if i == - 1 else old_strides [i ] for i in new_order ]
236
+ out [0 ][0 ] = res
250
237
251
238
def infer_shape (self , fgraph , node , shapes ):
252
239
(ishp ,) = shapes
0 commit comments