2
2
3
3
import torch
4
4
5
+ from . import _detail as _impl
5
6
from . import _helpers
6
- from ._detail import _flips , _reductions , _util
7
- from ._detail import implementations as _impl
7
+ from ._detail import _util
8
8
from ._normalizations import (
9
9
ArrayLike ,
10
10
AxisLike ,
14
14
normalizer ,
15
15
)
16
16
17
+ NoValue = _util .NoValue
18
+
17
19
18
20
@normalizer
19
21
def nonzero (a : ArrayLike ):
@@ -158,13 +160,13 @@ def moveaxis(a: ArrayLike, source, destination):
158
160
159
161
@normalizer
160
162
def swapaxes (a : ArrayLike , axis1 , axis2 ):
161
- result = _flips .swapaxes (a , axis1 , axis2 )
163
+ result = _impl .swapaxes (a , axis1 , axis2 )
162
164
return _helpers .array_from (result )
163
165
164
166
165
167
@normalizer
166
168
def rollaxis (a : ArrayLike , axis , start = 0 ):
167
- result = _flips .rollaxis (a , axis , start )
169
+ result = _impl .rollaxis (a , axis , start )
168
170
return _helpers .array_from (result )
169
171
170
172
@@ -230,9 +232,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
230
232
# ### reductions ###
231
233
232
234
233
- NoValue = None # FIXME
234
-
235
-
236
235
@normalizer
237
236
def sum (
238
237
a : ArrayLike ,
@@ -243,7 +242,7 @@ def sum(
243
242
initial = NoValue ,
244
243
where = NoValue ,
245
244
):
246
- result = _reductions .sum (
245
+ result = _impl .sum (
247
246
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
248
247
)
249
248
return _helpers .result_or_out (result , out )
@@ -259,7 +258,7 @@ def prod(
259
258
initial = NoValue ,
260
259
where = NoValue ,
261
260
):
262
- result = _reductions .prod (
261
+ result = _impl .prod (
263
262
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
264
263
)
265
264
return _helpers .result_or_out (result , out )
@@ -278,9 +277,7 @@ def mean(
278
277
* ,
279
278
where = NoValue ,
280
279
):
281
- result = _reductions .mean (
282
- a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims
283
- )
280
+ result = _impl .mean (a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims )
284
281
return _helpers .result_or_out (result , out )
285
282
286
283
@@ -295,7 +292,7 @@ def var(
295
292
* ,
296
293
where = NoValue ,
297
294
):
298
- result = _reductions .var (
295
+ result = _impl .var (
299
296
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
300
297
)
301
298
return _helpers .result_or_out (result , out )
@@ -312,7 +309,7 @@ def std(
312
309
* ,
313
310
where = NoValue ,
314
311
):
315
- result = _reductions .std (
312
+ result = _impl .std (
316
313
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
317
314
)
318
315
return _helpers .result_or_out (result , out )
@@ -326,7 +323,7 @@ def argmin(
326
323
* ,
327
324
keepdims = NoValue ,
328
325
):
329
- result = _reductions .argmin (a , axis = axis , keepdims = keepdims )
326
+ result = _impl .argmin (a , axis = axis , keepdims = keepdims )
330
327
return _helpers .result_or_out (result , out )
331
328
332
329
@@ -338,7 +335,7 @@ def argmax(
338
335
* ,
339
336
keepdims = NoValue ,
340
337
):
341
- result = _reductions .argmax (a , axis = axis , keepdims = keepdims )
338
+ result = _impl .argmax (a , axis = axis , keepdims = keepdims )
342
339
return _helpers .result_or_out (result , out )
343
340
344
341
@@ -351,9 +348,7 @@ def amax(
351
348
initial = NoValue ,
352
349
where = NoValue ,
353
350
):
354
- result = _reductions .max (
355
- a , axis = axis , initial = initial , where = where , keepdims = keepdims
356
- )
351
+ result = _impl .max (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
357
352
return _helpers .result_or_out (result , out )
358
353
359
354
@@ -369,9 +364,7 @@ def amin(
369
364
initial = NoValue ,
370
365
where = NoValue ,
371
366
):
372
- result = _reductions .min (
373
- a , axis = axis , initial = initial , where = where , keepdims = keepdims
374
- )
367
+ result = _impl .min (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
375
368
return _helpers .result_or_out (result , out )
376
369
377
370
@@ -382,7 +375,7 @@ def amin(
382
375
def ptp (
383
376
a : ArrayLike , axis : AxisLike = None , out : Optional [NDArray ] = None , keepdims = NoValue
384
377
):
385
- result = _reductions .ptp (a , axis = axis , keepdims = keepdims )
378
+ result = _impl .ptp (a , axis = axis , keepdims = keepdims )
386
379
return _helpers .result_or_out (result , out )
387
380
388
381
@@ -395,7 +388,7 @@ def all(
395
388
* ,
396
389
where = NoValue ,
397
390
):
398
- result = _reductions .all (a , axis = axis , where = where , keepdims = keepdims )
391
+ result = _impl .all (a , axis = axis , where = where , keepdims = keepdims )
399
392
return _helpers .result_or_out (result , out )
400
393
401
394
@@ -408,13 +401,13 @@ def any(
408
401
* ,
409
402
where = NoValue ,
410
403
):
411
- result = _reductions .any (a , axis = axis , where = where , keepdims = keepdims )
404
+ result = _impl .any (a , axis = axis , where = where , keepdims = keepdims )
412
405
return _helpers .result_or_out (result , out )
413
406
414
407
415
408
@normalizer
416
409
def count_nonzero (a : ArrayLike , axis : AxisLike = None , * , keepdims = False ):
417
- result = _reductions .count_nonzero (a , axis = axis , keepdims = keepdims )
410
+ result = _impl .count_nonzero (a , axis = axis , keepdims = keepdims )
418
411
return _helpers .array_from (result )
419
412
420
413
@@ -425,7 +418,7 @@ def cumsum(
425
418
dtype : DTypeLike = None ,
426
419
out : Optional [NDArray ] = None ,
427
420
):
428
- result = _reductions .cumsum (a , axis = axis , dtype = dtype )
421
+ result = _impl .cumsum (a , axis = axis , dtype = dtype )
429
422
return _helpers .result_or_out (result , out )
430
423
431
424
@@ -436,7 +429,7 @@ def cumprod(
436
429
dtype : DTypeLike = None ,
437
430
out : Optional [NDArray ] = None ,
438
431
):
439
- result = _reductions .cumprod (a , axis = axis , dtype = dtype )
432
+ result = _impl .cumprod (a , axis = axis , dtype = dtype )
440
433
return _helpers .result_or_out (result , out )
441
434
442
435
@@ -458,5 +451,5 @@ def quantile(
458
451
if interpolation is not None :
459
452
raise ValueError ("'interpolation' argument is deprecated; use 'method' instead" )
460
453
461
- result = _reductions .quantile (a , q , axis , method = method , keepdims = keepdims )
454
+ result = _impl .quantile (a , q , axis , method = method , keepdims = keepdims )
462
455
return _helpers .result_or_out (result , out , promote_scalar = True )
0 commit comments