@@ -137,31 +137,29 @@ def clip(a : ArrayLike, min : Optional[ArrayLike]=None, max : Optional[ArrayLike
137
137
return _helpers .result_or_out (result , out )
138
138
139
139
140
- def repeat (a , repeats , axis = None ):
141
- tensor , t_repeats = _helpers .to_tensors (a , repeats ) # XXX: scalar repeats
142
- result = torch .repeat_interleave (tensor , t_repeats , axis )
140
+ @normalizer
141
+ def repeat (a : ArrayLike , repeats : ArrayLike , axis = None ):
142
+ # XXX: scalar repeats; ArrayLikeOrScalar ?
143
+ result = torch .repeat_interleave (a , repeats , axis )
143
144
return _helpers .array_from (result )
144
145
145
146
146
147
# ### diag et al ###
147
148
148
-
149
- def diagonal (a , offset = 0 , axis1 = 0 , axis2 = 1 ):
150
- (tensor ,) = _helpers .to_tensors (a )
151
- result = _impl .diagonal (tensor , offset , axis1 , axis2 )
149
+ @normalizer
150
+ def diagonal (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 ):
151
+ result = _impl .diagonal (a , offset , axis1 , axis2 )
152
152
return _helpers .array_from (result )
153
153
154
154
155
155
@normalizer
156
156
def trace (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 , dtype : DTypeLike = None , out = None ):
157
- # (tensor,) = _helpers.to_tensors(a)
158
157
result = _impl .trace (a , offset , axis1 , axis2 , dtype )
159
158
return _helpers .result_or_out (result , out )
160
159
161
160
162
161
@normalizer
163
162
def eye (N , M = None , k = 0 , dtype : DTypeLike = float , order = "C" , * , like : SubokLike = None ):
164
- # _util.subok_not_ok(like)
165
163
if order != "C" :
166
164
raise NotImplementedError
167
165
result = _impl .eye (N , M , k , dtype )
@@ -170,20 +168,19 @@ def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike
170
168
171
169
@normalizer
172
170
def identity (n , dtype : DTypeLike = None , * , like : SubokLike = None ):
173
- ## _util.subok_not_ok(like)
174
171
result = torch .eye (n , dtype = dtype )
175
172
return _helpers .array_from (result )
176
173
177
174
178
- def diag ( v , k = 0 ):
179
- ( tensor ,) = _helpers . to_tensors ( v )
180
- result = torch .diag (tensor , k )
175
+ @ normalizer
176
+ def diag ( v : ArrayLike , k = 0 ):
177
+ result = torch .diag (v , k )
181
178
return _helpers .array_from (result )
182
179
183
180
184
- def diagflat ( v , k = 0 ):
185
- ( tensor ,) = _helpers . to_tensors ( v )
186
- result = torch .diagflat (tensor , k )
181
+ @ normalizer
182
+ def diagflat ( v : ArrayLike , k = 0 ):
183
+ result = torch .diagflat (v , k )
187
184
return _helpers .array_from (result )
188
185
189
186
@@ -192,124 +189,126 @@ def diag_indices(n, ndim=2):
192
189
return _helpers .tuple_arrays_from (result )
193
190
194
191
195
- def diag_indices_from ( arr ):
196
- ( tensor ,) = _helpers . to_tensors ( arr )
197
- result = _impl .diag_indices_from (tensor )
192
+ @ normalizer
193
+ def diag_indices_from ( arr : ArrayLike ):
194
+ result = _impl .diag_indices_from (arr )
198
195
return _helpers .tuple_arrays_from (result )
199
196
200
197
201
- def fill_diagonal ( a , val , wrap = False ):
202
- tensor , t_val = _helpers . to_tensors ( a , val )
203
- result = _impl .fill_diagonal (tensor , t_val , wrap )
198
+ @ normalizer
199
+ def fill_diagonal ( a : ArrayLike , val : ArrayLike , wrap = False ):
200
+ result = _impl .fill_diagonal (a , val , wrap )
204
201
return _helpers .array_from (result )
205
202
206
203
207
- def vdot (a , b , / ):
208
- t_a , t_b = _helpers .to_tensors (a , b )
209
- result = _impl .vdot (t_a , t_b )
204
+ @normalizer
205
+ def vdot (a : ArrayLike , b : ArrayLike , / ):
206
+ # t_a, t_b = _helpers.to_tensors(a, b)
207
+ result = _impl .vdot (a , b )
210
208
return result .item ()
211
209
212
210
213
- def dot (a , b , out = None ):
214
- t_a , t_b = _helpers .to_tensors (a , b )
215
- result = _impl .dot (t_a , t_b )
211
+ @normalizer
212
+ def dot (a : ArrayLike , b : ArrayLike , out = None ):
213
+ # t_a, t_b = _helpers.to_tensors(a, b)
214
+ result = _impl .dot (a , b )
216
215
return _helpers .result_or_out (result , out )
217
216
218
217
219
218
# ### sort and partition ###
220
219
221
220
222
- def sort ( a , axis = - 1 , kind = None , order = None ):
223
- ( tensor ,) = _helpers . to_tensors ( a )
224
- result = _impl .sort (tensor , axis , kind , order )
221
+ @ normalizer
222
+ def sort ( a : ArrayLike , axis = - 1 , kind = None , order = None ):
223
+ result = _impl .sort (a , axis , kind , order )
225
224
return _helpers .array_from (result )
226
225
227
226
228
- def argsort ( a , axis = - 1 , kind = None , order = None ):
229
- ( tensor ,) = _helpers . to_tensors ( a )
230
- result = _impl .argsort (tensor , axis , kind , order )
227
+ @ normalizer
228
+ def argsort ( a : ArrayLike , axis = - 1 , kind = None , order = None ):
229
+ result = _impl .argsort (a , axis , kind , order )
231
230
return _helpers .array_from (result )
232
231
233
232
234
- def searchsorted ( a , v , side = "left" , sorter = None ):
235
- a_t , v_t , sorter_t = _helpers . to_tensors_or_none ( a , v , sorter )
236
- result = torch .searchsorted (a_t , v_t , side = side , sorter = sorter_t )
233
+ @ normalizer
234
+ def searchsorted ( a : ArrayLike , v : ArrayLike , side = "left" , sorter : Optional [ ArrayLike ] = None ):
235
+ result = torch .searchsorted (a , v , side = side , sorter = sorter )
237
236
return _helpers .array_from (result )
238
237
239
238
240
239
# ### swap/move/roll axis ###
241
240
242
241
243
- def moveaxis ( a , source , destination ):
244
- ( tensor ,) = _helpers . to_tensors ( a )
245
- result = _impl .moveaxis (tensor , source , destination )
242
+ @ normalizer
243
+ def moveaxis ( a : ArrayLike , source , destination ):
244
+ result = _impl .moveaxis (a , source , destination )
246
245
return _helpers .array_from (result )
247
246
248
247
249
- def swapaxes ( a , axis1 , axis2 ):
250
- ( tensor ,) = _helpers . to_tensors ( a )
251
- result = _flips .swapaxes (tensor , axis1 , axis2 )
248
+ @ normalizer
249
+ def swapaxes ( a : ArrayLike , axis1 , axis2 ):
250
+ result = _flips .swapaxes (a , axis1 , axis2 )
252
251
return _helpers .array_from (result )
253
252
254
253
255
- def rollaxis ( a , axis , start = 0 ):
256
- ( tensor ,) = _helpers . to_tensors ( a )
254
+ @ normalizer
255
+ def rollaxis ( a : ArrayLike , axis , start = 0 ):
257
256
result = _flips .rollaxis (a , axis , start )
258
257
return _helpers .array_from (result )
259
258
260
259
261
260
# ### shape manipulations ###
262
261
263
262
264
- def squeeze ( a , axis = None ):
265
- ( tensor ,) = _helpers . to_tensors ( a )
266
- result = _impl .squeeze (tensor , axis )
263
+ @ normalizer
264
+ def squeeze ( a : ArrayLike , axis = None ):
265
+ result = _impl .squeeze (a , axis )
267
266
return _helpers .array_from (result , a )
268
267
269
268
270
- def reshape ( a , newshape , order = "C" ):
271
- ( tensor ,) = _helpers . to_tensors ( a )
272
- result = _impl .reshape (tensor , newshape , order = order )
269
+ @ normalizer
270
+ def reshape ( a : ArrayLike , newshape , order = "C" ):
271
+ result = _impl .reshape (a , newshape , order = order )
273
272
return _helpers .array_from (result , a )
274
273
275
274
276
- def transpose ( a , axes = None ):
277
- ( tensor ,) = _helpers . to_tensors ( a )
278
- result = _impl .transpose (tensor , axes )
275
+ @ normalizer
276
+ def transpose ( a : ArrayLike , axes = None ):
277
+ result = _impl .transpose (a , axes )
279
278
return _helpers .array_from (result , a )
280
279
281
280
282
- def ravel ( a , order = "C" ):
283
- ( tensor ,) = _helpers . to_tensors ( a )
284
- result = _impl .ravel (tensor )
281
+ @ normalizer
282
+ def ravel ( a : ArrayLike , order = "C" ):
283
+ result = _impl .ravel (a )
285
284
return _helpers .array_from (result , a )
286
285
287
286
288
287
# leading underscore since arr.flatten exists but np.flatten does not
289
- def _flatten ( a , order = "C" ):
290
- ( tensor ,) = _helpers . to_tensors ( a )
291
- result = _impl ._flatten (tensor )
288
+ @ normalizer
289
+ def _flatten ( a : ArrayLike , order = "C" ):
290
+ result = _impl ._flatten (a )
292
291
return _helpers .array_from (result , a )
293
292
294
293
295
294
# ### Type/shape etc queries ###
296
295
297
296
298
- def real ( a ):
299
- ( tensor ,) = _helpers . to_tensors ( a )
300
- result = torch .real (tensor )
297
+ @ normalizer
298
+ def real ( a : ArrayLike ):
299
+ result = torch .real (a )
301
300
return _helpers .array_from (result )
302
301
303
302
304
- def imag ( a ):
305
- ( tensor ,) = _helpers . to_tensors ( a )
306
- result = _impl .imag (tensor )
303
+ @ normalizer
304
+ def imag ( a : ArrayLike ):
305
+ result = _impl .imag (a )
307
306
return _helpers .array_from (result )
308
307
309
308
310
- def round_ ( a , decimals = 0 , out = None ):
311
- ( tensor ,) = _helpers . to_tensors ( a )
312
- result = _impl .round (tensor , decimals )
309
+ @ normalizer
310
+ def round_ ( a : ArrayLike , decimals = 0 , out = None ):
311
+ result = _impl .round (a , decimals )
313
312
return _helpers .result_or_out (result , out )
314
313
315
314
0 commit comments