2
2
3
3
import torch
4
4
5
+ from . import _detail as _impl
5
6
from . import _helpers
6
- from ._detail import _flips , _reductions , _util
7
- from ._detail import implementations as _impl
7
+ from ._detail import _util
8
8
from ._normalizations import (
9
9
ArrayLike ,
10
10
AxisLike ,
15
15
normalizer ,
16
16
)
17
17
18
+ NoValue = _util .NoValue
19
+
18
20
19
21
@normalizer
20
22
def nonzero (a : ArrayLike ):
@@ -159,13 +161,13 @@ def moveaxis(a: ArrayLike, source, destination):
159
161
160
162
@normalizer
161
163
def swapaxes (a : ArrayLike , axis1 , axis2 ):
162
- result = _flips .swapaxes (a , axis1 , axis2 )
164
+ result = _impl .swapaxes (a , axis1 , axis2 )
163
165
return _helpers .array_from (result )
164
166
165
167
166
168
@normalizer
167
169
def rollaxis (a : ArrayLike , axis , start = 0 ):
168
- result = _flips .rollaxis (a , axis , start )
170
+ result = _impl .rollaxis (a , axis , start )
169
171
return _helpers .array_from (result )
170
172
171
173
@@ -231,9 +233,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
231
233
# ### reductions ###
232
234
233
235
234
- NoValue = None # FIXME
235
-
236
-
237
236
@normalizer
238
237
def sum (
239
238
a : ArrayLike ,
@@ -244,7 +243,7 @@ def sum(
244
243
initial = NoValue ,
245
244
where = NoValue ,
246
245
):
247
- result = _reductions .sum (
246
+ result = _impl .sum (
248
247
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
249
248
)
250
249
return _helpers .result_or_out (result , out )
@@ -260,7 +259,7 @@ def prod(
260
259
initial = NoValue ,
261
260
where = NoValue ,
262
261
):
263
- result = _reductions .prod (
262
+ result = _impl .prod (
264
263
a , axis = axis , dtype = dtype , initial = initial , where = where , keepdims = keepdims
265
264
)
266
265
return _helpers .result_or_out (result , out )
@@ -279,9 +278,7 @@ def mean(
279
278
* ,
280
279
where = NoValue ,
281
280
):
282
- result = _reductions .mean (
283
- a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims
284
- )
281
+ result = _impl .mean (a , axis = axis , dtype = dtype , where = NoValue , keepdims = keepdims )
285
282
return _helpers .result_or_out (result , out )
286
283
287
284
@@ -296,7 +293,7 @@ def var(
296
293
* ,
297
294
where = NoValue ,
298
295
):
299
- result = _reductions .var (
296
+ result = _impl .var (
300
297
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
301
298
)
302
299
return _helpers .result_or_out (result , out )
@@ -313,7 +310,7 @@ def std(
313
310
* ,
314
311
where = NoValue ,
315
312
):
316
- result = _reductions .std (
313
+ result = _impl .std (
317
314
a , axis = axis , dtype = dtype , ddof = ddof , where = where , keepdims = keepdims
318
315
)
319
316
return _helpers .result_or_out (result , out )
@@ -327,7 +324,7 @@ def argmin(
327
324
* ,
328
325
keepdims = NoValue ,
329
326
):
330
- result = _reductions .argmin (a , axis = axis , keepdims = keepdims )
327
+ result = _impl .argmin (a , axis = axis , keepdims = keepdims )
331
328
return _helpers .result_or_out (result , out )
332
329
333
330
@@ -339,7 +336,7 @@ def argmax(
339
336
* ,
340
337
keepdims = NoValue ,
341
338
):
342
- result = _reductions .argmax (a , axis = axis , keepdims = keepdims )
339
+ result = _impl .argmax (a , axis = axis , keepdims = keepdims )
343
340
return _helpers .result_or_out (result , out )
344
341
345
342
@@ -352,9 +349,7 @@ def amax(
352
349
initial = NoValue ,
353
350
where = NoValue ,
354
351
):
355
- result = _reductions .max (
356
- a , axis = axis , initial = initial , where = where , keepdims = keepdims
357
- )
352
+ result = _impl .max (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
358
353
return _helpers .result_or_out (result , out )
359
354
360
355
@@ -370,9 +365,7 @@ def amin(
370
365
initial = NoValue ,
371
366
where = NoValue ,
372
367
):
373
- result = _reductions .min (
374
- a , axis = axis , initial = initial , where = where , keepdims = keepdims
375
- )
368
+ result = _impl .min (a , axis = axis , initial = initial , where = where , keepdims = keepdims )
376
369
return _helpers .result_or_out (result , out )
377
370
378
371
@@ -383,7 +376,7 @@ def amin(
383
376
def ptp (
384
377
a : ArrayLike , axis : AxisLike = None , out : Optional [NDArray ] = None , keepdims = NoValue
385
378
):
386
- result = _reductions .ptp (a , axis = axis , keepdims = keepdims )
379
+ result = _impl .ptp (a , axis = axis , keepdims = keepdims )
387
380
return _helpers .result_or_out (result , out )
388
381
389
382
@@ -396,7 +389,7 @@ def all(
396
389
* ,
397
390
where = NoValue ,
398
391
):
399
- result = _reductions .all (a , axis = axis , where = where , keepdims = keepdims )
392
+ result = _impl .all (a , axis = axis , where = where , keepdims = keepdims )
400
393
return _helpers .result_or_out (result , out )
401
394
402
395
@@ -409,13 +402,13 @@ def any(
409
402
* ,
410
403
where = NoValue ,
411
404
):
412
- result = _reductions .any (a , axis = axis , where = where , keepdims = keepdims )
405
+ result = _impl .any (a , axis = axis , where = where , keepdims = keepdims )
413
406
return _helpers .result_or_out (result , out )
414
407
415
408
416
409
@normalizer
417
410
def count_nonzero (a : ArrayLike , axis : AxisLike = None , * , keepdims = False ):
418
- result = _reductions .count_nonzero (a , axis = axis , keepdims = keepdims )
411
+ result = _impl .count_nonzero (a , axis = axis , keepdims = keepdims )
419
412
return _helpers .array_from (result )
420
413
421
414
@@ -426,7 +419,7 @@ def cumsum(
426
419
dtype : DTypeLike = None ,
427
420
out : Optional [NDArray ] = None ,
428
421
):
429
- result = _reductions .cumsum (a , axis = axis , dtype = dtype )
422
+ result = _impl .cumsum (a , axis = axis , dtype = dtype )
430
423
return _helpers .result_or_out (result , out )
431
424
432
425
@@ -437,7 +430,7 @@ def cumprod(
437
430
dtype : DTypeLike = None ,
438
431
out : Optional [NDArray ] = None ,
439
432
):
440
- result = _reductions .cumprod (a , axis = axis , dtype = dtype )
433
+ result = _impl .cumprod (a , axis = axis , dtype = dtype )
441
434
return _helpers .result_or_out (result , out )
442
435
443
436
@@ -459,5 +452,5 @@ def quantile(
459
452
if interpolation is not None :
460
453
raise ValueError ("'interpolation' argument is deprecated; use 'method' instead" )
461
454
462
- result = _reductions .quantile (a , q , axis , method = method , keepdims = keepdims )
455
+ result = _impl .quantile (a , q , axis , method = method , keepdims = keepdims )
463
456
return _helpers .result_or_out (result , out , promote_scalar = True )
0 commit comments