3
3
import torch
4
4
5
5
from . import _helpers
6
- from ._detail import _flips , _reductions , _util
7
- from ._detail import implementations as _impl
6
+ from . import _detail as _impl
7
+ from ._detail import _util
8
+
8
9
from ._normalizations import (
9
10
ArrayLike ,
10
11
AxisLike ,
16
17
)
17
18
18
19
20
+ NoValue = _util .NoValue
21
+
22
+
19
23
@normalizer
20
24
def nonzero (a : ArrayLike ):
21
25
result = a .nonzero (as_tuple = True )
@@ -159,13 +163,13 @@ def moveaxis(a: ArrayLike, source, destination):
159
163
160
164
@normalizer
161
165
def swapaxes (a : ArrayLike , axis1 , axis2 ):
162
- result = _flips .swapaxes (a , axis1 , axis2 )
166
+ result = _impl .swapaxes (a , axis1 , axis2 )
163
167
return _helpers .array_from (result )
164
168
165
169
166
170
@normalizer
167
171
def rollaxis (a : ArrayLike , axis , start = 0 ):
168
- result = _flips .rollaxis (a , axis , start )
172
+ result = _impl .rollaxis (a , axis , start )
169
173
return _helpers .array_from (result )
170
174
171
175
@@ -230,10 +234,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
230
234
231
235
# ### reductions ###
232
236
233
-
234
- NoValue = None # FIXME
235
-
236
-
237
237
@normalizer
238
238
def sum (
239
239
a : ArrayLike ,
@@ -244,7 +244,7 @@ def sum(
244
244
initial = NoValue ,
245
245
where = NoValue ,
246
246
):
247
- result = _reductions .sum (
247
+ result = _impl .sum (
248
248
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
249
249
)
250
250
return _helpers .result_or_out (result , out )
@@ -260,7 +260,7 @@ def prod(
260
260
initial = NoValue ,
261
261
where = NoValue ,
262
262
):
263
- result = _reductions .prod (
263
+ result = _impl .prod (
264
264
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
265
265
)
266
266
return _helpers .result_or_out (result , out )
@@ -279,7 +279,7 @@ def mean(
279
279
* ,
280
280
where = NoValue ,
281
281
):
282
- result = _reductions .mean (
282
+ result = _impl .mean (
283
283
a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims
284
284
)
285
285
return _helpers .result_or_out (result , out )
@@ -296,7 +296,7 @@ def var(
296
296
* ,
297
297
where = NoValue ,
298
298
):
299
- result = _reductions .var (
299
+ result = _impl .var (
300
300
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
301
301
)
302
302
return _helpers .result_or_out (result , out )
@@ -313,7 +313,7 @@ def std(
313
313
* ,
314
314
where = NoValue ,
315
315
):
316
- result = _reductions .std (
316
+ result = _impl .std (
317
317
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
318
318
)
319
319
return _helpers .result_or_out (result , out )
@@ -327,7 +327,7 @@ def argmin(
327
327
* ,
328
328
keepdims = NoValue ,
329
329
):
330
- result = _reductions .argmin (a , axis = axis , keepdims = keepdims )
330
+ result = _impl .argmin (a , axis = axis , keepdims = keepdims )
331
331
return _helpers .result_or_out (result , out )
332
332
333
333
@@ -339,7 +339,7 @@ def argmax(
339
339
* ,
340
340
keepdims = NoValue ,
341
341
):
342
- result = _reductions .argmax (a , axis = axis , keepdims = keepdims )
342
+ result = _impl .argmax (a , axis = axis , keepdims = keepdims )
343
343
return _helpers .result_or_out (result , out )
344
344
345
345
@@ -352,7 +352,7 @@ def amax(
352
352
initial = NoValue ,
353
353
where = NoValue ,
354
354
):
355
- result = _reductions .max (
355
+ result = _impl .max (
356
356
a , axis = axis , initial = initial , where = where , keepdims = keepdims
357
357
)
358
358
return _helpers .result_or_out (result , out )
@@ -370,7 +370,7 @@ def amin(
370
370
initial = NoValue ,
371
371
where = NoValue ,
372
372
):
373
- result = _reductions .min (
373
+ result = _impl .min (
374
374
a , axis = axis , initial = initial , where = where , keepdims = keepdims
375
375
)
376
376
return _helpers .result_or_out (result , out )
@@ -383,7 +383,7 @@ def amin(
383
383
def ptp (
384
384
a : ArrayLike , axis : AxisLike = None , out : Optional [NDArray ] = None , keepdims = NoValue
385
385
):
386
- result = _reductions .ptp (a , axis = axis , keepdims = keepdims )
386
+ result = _impl .ptp (a , axis = axis , keepdims = keepdims )
387
387
return _helpers .result_or_out (result , out )
388
388
389
389
@@ -396,7 +396,7 @@ def all(
396
396
* ,
397
397
where = NoValue ,
398
398
):
399
- result = _reductions .all (a , axis = axis , where = where , keepdims = keepdims )
399
+ result = _impl .all (a , axis = axis , where = where , keepdims = keepdims )
400
400
return _helpers .result_or_out (result , out )
401
401
402
402
@@ -409,13 +409,13 @@ def any(
409
409
* ,
410
410
where = NoValue ,
411
411
):
412
- result = _reductions .any (a , axis = axis , where = where , keepdims = keepdims )
412
+ result = _impl .any (a , axis = axis , where = where , keepdims = keepdims )
413
413
return _helpers .result_or_out (result , out )
414
414
415
415
416
416
@normalizer
417
417
def count_nonzero (a : ArrayLike , axis : AxisLike = None , * , keepdims = False ):
418
- result = _reductions .count_nonzero (a , axis = axis , keepdims = keepdims )
418
+ result = _impl .count_nonzero (a , axis = axis , keepdims = keepdims )
419
419
return _helpers .array_from (result )
420
420
421
421
@@ -426,7 +426,7 @@ def cumsum(
426
426
dtype : DTypeLike = None ,
427
427
out : Optional [NDArray ] = None ,
428
428
):
429
- result = _reductions .cumsum (a , axis = axis , dtype = dtype )
429
+ result = _impl .cumsum (a , axis = axis , dtype = dtype )
430
430
return _helpers .result_or_out (result , out )
431
431
432
432
@@ -437,7 +437,7 @@ def cumprod(
437
437
dtype : DTypeLike = None ,
438
438
out : Optional [NDArray ] = None ,
439
439
):
440
- result = _reductions .cumprod (a , axis = axis , dtype = dtype )
440
+ result = _impl .cumprod (a , axis = axis , dtype = dtype )
441
441
return _helpers .result_or_out (result , out )
442
442
443
443
@@ -459,5 +459,5 @@ def quantile(
459
459
if interpolation is not None :
460
460
raise ValueError ("'interpolation' argument is deprecated; use 'method' instead" )
461
461
462
- result = _reductions .quantile (a , q , axis , method = method , keepdims = keepdims )
462
+ result = _impl .quantile (a , q , axis , method = method , keepdims = keepdims )
463
463
return _helpers .result_or_out (result , out , promote_scalar = True )
0 commit comments