@@ -56,7 +56,7 @@ def normalize_subok_like(arg, name):
56
56
ArrayLike : normalize_array_like ,
57
57
Optional [ArrayLike ]: normalize_optional_array_like ,
58
58
Sequence [ArrayLike ]: normalize_seq_array_like ,
59
- UnpackedSeqArrayLike : normalize_seq_array_like , # cf handling in normalize
59
+ UnpackedSeqArrayLike : normalize_seq_array_like , # cf handling in normalize
60
60
DTypeLike : normalize_dtype ,
61
61
SubokLike : normalize_subok_like ,
62
62
}
@@ -99,7 +99,6 @@ def wrapped(*args, **kwds):
99
99
print (arg , name , parm .annotation )
100
100
lst .append (normalize_this (arg , parm ))
101
101
102
-
103
102
# normalize keyword arguments
104
103
for name , arg in kwds .items ():
105
104
if not name in sig .parameters :
@@ -156,24 +155,30 @@ def argwhere(a):
156
155
157
156
158
157
@normalizer
159
- def clip (a : ArrayLike , min : Optional [ArrayLike ]= None , max : Optional [ArrayLike ]= None , out = None ):
158
+ def clip (
159
+ a : ArrayLike ,
160
+ min : Optional [ArrayLike ] = None ,
161
+ max : Optional [ArrayLike ] = None ,
162
+ out = None ,
163
+ ):
160
164
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
161
165
# one of them to be None. Follow the more lax version.
162
166
result = _impl .clip (a , min , max )
163
167
return _helpers .result_or_out (result , out )
164
168
165
169
166
170
@normalizer
167
- def repeat (a : ArrayLike , repeats : ArrayLike , axis = None ):
171
+ def repeat (a : ArrayLike , repeats : ArrayLike , axis = None ):
168
172
# XXX: scalar repeats; ArrayLikeOrScalar ?
169
173
result = torch .repeat_interleave (a , repeats , axis )
170
174
return _helpers .array_from (result )
171
175
172
176
173
177
# ### diag et al ###
174
178
179
+
175
180
@normalizer
176
- def diagonal (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 ):
181
+ def diagonal (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 ):
177
182
result = _impl .diagonal (a , offset , axis1 , axis2 )
178
183
return _helpers .array_from (result )
179
184
@@ -199,13 +204,13 @@ def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
199
204
200
205
201
206
@normalizer
202
- def diag (v : ArrayLike , k = 0 ):
207
+ def diag (v : ArrayLike , k = 0 ):
203
208
result = torch .diag (v , k )
204
209
return _helpers .array_from (result )
205
210
206
211
207
212
@normalizer
208
- def diagflat (v : ArrayLike , k = 0 ):
213
+ def diagflat (v : ArrayLike , k = 0 ):
209
214
result = torch .diagflat (v , k )
210
215
return _helpers .array_from (result )
211
216
@@ -216,27 +221,25 @@ def diag_indices(n, ndim=2):
216
221
217
222
218
223
@normalizer
219
- def diag_indices_from (arr : ArrayLike ):
224
+ def diag_indices_from (arr : ArrayLike ):
220
225
result = _impl .diag_indices_from (arr )
221
226
return _helpers .tuple_arrays_from (result )
222
227
223
228
224
229
@normalizer
225
- def fill_diagonal (a : ArrayLike , val : ArrayLike , wrap = False ):
230
+ def fill_diagonal (a : ArrayLike , val : ArrayLike , wrap = False ):
226
231
result = _impl .fill_diagonal (a , val , wrap )
227
232
return _helpers .array_from (result )
228
233
229
234
230
235
@normalizer
231
- def vdot (a : ArrayLike , b : ArrayLike , / ):
232
- # t_a, t_b = _helpers.to_tensors(a, b)
236
+ def vdot (a : ArrayLike , b : ArrayLike , / ):
233
237
result = _impl .vdot (a , b )
234
238
return result .item ()
235
239
236
240
237
241
@normalizer
238
- def dot (a : ArrayLike , b : ArrayLike , out = None ):
239
- # t_a, t_b = _helpers.to_tensors(a, b)
242
+ def dot (a : ArrayLike , b : ArrayLike , out = None ):
240
243
result = _impl .dot (a , b )
241
244
return _helpers .result_or_out (result , out )
242
245
@@ -245,19 +248,21 @@ def dot(a : ArrayLike, b : ArrayLike, out=None):
245
248
246
249
247
250
@normalizer
248
- def sort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
251
+ def sort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
249
252
result = _impl .sort (a , axis , kind , order )
250
253
return _helpers .array_from (result )
251
254
252
255
253
256
@normalizer
254
- def argsort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
257
+ def argsort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
255
258
result = _impl .argsort (a , axis , kind , order )
256
259
return _helpers .array_from (result )
257
260
258
261
259
262
@normalizer
260
- def searchsorted (a : ArrayLike , v : ArrayLike , side = "left" , sorter : Optional [ArrayLike ]= None ):
263
+ def searchsorted (
264
+ a : ArrayLike , v : ArrayLike , side = "left" , sorter : Optional [ArrayLike ] = None
265
+ ):
261
266
result = torch .searchsorted (a , v , side = side , sorter = sorter )
262
267
return _helpers .array_from (result )
263
268
@@ -266,19 +271,19 @@ def searchsorted(a : ArrayLike, v : ArrayLike, side="left", sorter : Optional[Ar
266
271
267
272
268
273
@normalizer
269
- def moveaxis (a : ArrayLike , source , destination ):
274
+ def moveaxis (a : ArrayLike , source , destination ):
270
275
result = _impl .moveaxis (a , source , destination )
271
276
return _helpers .array_from (result )
272
277
273
278
274
279
@normalizer
275
- def swapaxes (a : ArrayLike , axis1 , axis2 ):
280
+ def swapaxes (a : ArrayLike , axis1 , axis2 ):
276
281
result = _flips .swapaxes (a , axis1 , axis2 )
277
282
return _helpers .array_from (result )
278
283
279
284
280
285
@normalizer
281
- def rollaxis (a : ArrayLike , axis , start = 0 ):
286
+ def rollaxis (a : ArrayLike , axis , start = 0 ):
282
287
result = _flips .rollaxis (a , axis , start )
283
288
return _helpers .array_from (result )
284
289
@@ -287,32 +292,32 @@ def rollaxis(a : ArrayLike, axis, start=0):
287
292
288
293
289
294
@normalizer
290
- def squeeze (a : ArrayLike , axis = None ):
295
+ def squeeze (a : ArrayLike , axis = None ):
291
296
result = _impl .squeeze (a , axis )
292
297
return _helpers .array_from (result , a )
293
298
294
299
295
300
@normalizer
296
- def reshape (a : ArrayLike , newshape , order = "C" ):
301
+ def reshape (a : ArrayLike , newshape , order = "C" ):
297
302
result = _impl .reshape (a , newshape , order = order )
298
303
return _helpers .array_from (result , a )
299
304
300
305
301
306
@normalizer
302
- def transpose (a : ArrayLike , axes = None ):
307
+ def transpose (a : ArrayLike , axes = None ):
303
308
result = _impl .transpose (a , axes )
304
309
return _helpers .array_from (result , a )
305
310
306
311
307
312
@normalizer
308
- def ravel (a : ArrayLike , order = "C" ):
313
+ def ravel (a : ArrayLike , order = "C" ):
309
314
result = _impl .ravel (a )
310
315
return _helpers .array_from (result , a )
311
316
312
317
313
318
# leading underscore since arr.flatten exists but np.flatten does not
314
319
@normalizer
315
- def _flatten (a : ArrayLike , order = "C" ):
320
+ def _flatten (a : ArrayLike , order = "C" ):
316
321
result = _impl ._flatten (a )
317
322
return _helpers .array_from (result , a )
318
323
@@ -321,7 +326,7 @@ def _flatten(a : ArrayLike, order="C"):
321
326
322
327
323
328
@normalizer
324
- def real (a : ArrayLike ):
329
+ def real (a : ArrayLike ):
325
330
result = torch .real (a )
326
331
return _helpers .array_from (result )
327
332
@@ -333,7 +338,7 @@ def imag(a: ArrayLike):
333
338
334
339
335
340
@normalizer
336
- def round_ (a : ArrayLike , decimals = 0 , out = None ):
341
+ def round_ (a : ArrayLike , decimals = 0 , out = None ):
337
342
result = _impl .round (a , decimals )
338
343
return _helpers .result_or_out (result , out )
339
344
0 commit comments