Skip to content

Commit 77b52f0

Browse files
committed
MAINT: make _detail a package, hide individual implementation files and import from _detail
1 parent 9cbbac6 commit 77b52f0

File tree

6 files changed

+42
-44
lines changed

6 files changed

+42
-44
lines changed

torch_np/_detail/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._flips import *
2+
from ._reductions import *
3+
4+
# leading underscore (ndarray.flatten yes, np.flatten no)
5+
from .implementations import *
6+
from .implementations import _flatten

torch_np/_detail/_reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import _dtypes_impl, _util
1212

13-
NoValue = None
13+
NoValue = _util.NoValue
1414

1515

1616
import functools

torch_np/_detail/_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import _dtypes_impl
99

10+
NoValue = None
1011

1112
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
1213
def is_sequence(seq):

torch_np/_funcs.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import torch
44

5+
from . import _detail as _impl
56
from . import _helpers
6-
from ._detail import _flips, _reductions, _util
7-
from ._detail import implementations as _impl
7+
from ._detail import _util
88
from ._normalizations import (
99
ArrayLike,
1010
AxisLike,
@@ -15,6 +15,8 @@
1515
normalizer,
1616
)
1717

18+
NoValue = _util.NoValue
19+
1820

1921
@normalizer
2022
def nonzero(a: ArrayLike):
@@ -159,13 +161,13 @@ def moveaxis(a: ArrayLike, source, destination):
159161

160162
@normalizer
161163
def swapaxes(a: ArrayLike, axis1, axis2):
162-
result = _flips.swapaxes(a, axis1, axis2)
164+
result = _impl.swapaxes(a, axis1, axis2)
163165
return _helpers.array_from(result)
164166

165167

166168
@normalizer
167169
def rollaxis(a: ArrayLike, axis, start=0):
168-
result = _flips.rollaxis(a, axis, start)
170+
result = _impl.rollaxis(a, axis, start)
169171
return _helpers.array_from(result)
170172

171173

@@ -231,9 +233,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
231233
# ### reductions ###
232234

233235

234-
NoValue = None # FIXME
235-
236-
237236
@normalizer
238237
def sum(
239238
a: ArrayLike,
@@ -244,7 +243,7 @@ def sum(
244243
initial=NoValue,
245244
where=NoValue,
246245
):
247-
result = _reductions.sum(
246+
result = _impl.sum(
248247
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
249248
)
250249
return _helpers.result_or_out(result, out)
@@ -260,7 +259,7 @@ def prod(
260259
initial=NoValue,
261260
where=NoValue,
262261
):
263-
result = _reductions.prod(
262+
result = _impl.prod(
264263
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
265264
)
266265
return _helpers.result_or_out(result, out)
@@ -279,9 +278,7 @@ def mean(
279278
*,
280279
where=NoValue,
281280
):
282-
result = _reductions.mean(
283-
a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims
284-
)
281+
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
285282
return _helpers.result_or_out(result, out)
286283

287284

@@ -296,7 +293,7 @@ def var(
296293
*,
297294
where=NoValue,
298295
):
299-
result = _reductions.var(
296+
result = _impl.var(
300297
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
301298
)
302299
return _helpers.result_or_out(result, out)
@@ -313,7 +310,7 @@ def std(
313310
*,
314311
where=NoValue,
315312
):
316-
result = _reductions.std(
313+
result = _impl.std(
317314
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
318315
)
319316
return _helpers.result_or_out(result, out)
@@ -327,7 +324,7 @@ def argmin(
327324
*,
328325
keepdims=NoValue,
329326
):
330-
result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
327+
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
331328
return _helpers.result_or_out(result, out)
332329

333330

@@ -339,7 +336,7 @@ def argmax(
339336
*,
340337
keepdims=NoValue,
341338
):
342-
result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
339+
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
343340
return _helpers.result_or_out(result, out)
344341

345342

@@ -352,9 +349,7 @@ def amax(
352349
initial=NoValue,
353350
where=NoValue,
354351
):
355-
result = _reductions.max(
356-
a, axis=axis, initial=initial, where=where, keepdims=keepdims
357-
)
352+
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
358353
return _helpers.result_or_out(result, out)
359354

360355

@@ -370,9 +365,7 @@ def amin(
370365
initial=NoValue,
371366
where=NoValue,
372367
):
373-
result = _reductions.min(
374-
a, axis=axis, initial=initial, where=where, keepdims=keepdims
375-
)
368+
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
376369
return _helpers.result_or_out(result, out)
377370

378371

@@ -383,7 +376,7 @@ def amin(
383376
def ptp(
384377
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
385378
):
386-
result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
379+
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
387380
return _helpers.result_or_out(result, out)
388381

389382

@@ -396,7 +389,7 @@ def all(
396389
*,
397390
where=NoValue,
398391
):
399-
result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
392+
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
400393
return _helpers.result_or_out(result, out)
401394

402395

@@ -409,13 +402,13 @@ def any(
409402
*,
410403
where=NoValue,
411404
):
412-
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
405+
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
413406
return _helpers.result_or_out(result, out)
414407

415408

416409
@normalizer
417410
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
418-
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
411+
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
419412
return _helpers.array_from(result)
420413

421414

@@ -426,7 +419,7 @@ def cumsum(
426419
dtype: DTypeLike = None,
427420
out: Optional[NDArray] = None,
428421
):
429-
result = _reductions.cumsum(a, axis=axis, dtype=dtype)
422+
result = _impl.cumsum(a, axis=axis, dtype=dtype)
430423
return _helpers.result_or_out(result, out)
431424

432425

@@ -437,7 +430,7 @@ def cumprod(
437430
dtype: DTypeLike = None,
438431
out: Optional[NDArray] = None,
439432
):
440-
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
433+
result = _impl.cumprod(a, axis=axis, dtype=dtype)
441434
return _helpers.result_or_out(result, out)
442435

443436

@@ -459,5 +452,5 @@ def quantile(
459452
if interpolation is not None:
460453
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
461454

462-
result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims)
455+
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
463456
return _helpers.result_or_out(result, out, promote_scalar=True)

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
6-
from ._detail import _dtypes_impl, _flips, _reductions, _util
6+
from ._detail import _dtypes_impl, _util
77
from ._detail import implementations as _impl
88

99
newaxis = None

torch_np/_wrapper.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
import torch
1010

11-
from . import _funcs, _helpers
12-
from ._detail import _dtypes_impl, _flips, _reductions, _util
13-
from ._detail import implementations as _impl
11+
from . import _decorators
12+
from . import _detail as _impl
13+
from . import _dtypes, _funcs, _helpers
14+
from ._detail import _dtypes_impl, _util
1415
from ._ndarray import asarray
1516
from ._normalizations import (
1617
ArrayLike,
@@ -21,6 +22,8 @@
2122
normalizer,
2223
)
2324

25+
NoValue = _util.NoValue
26+
2427
# Things to decide on (punt for now)
2528
#
2629
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -56,9 +59,6 @@
5659
# - optional out arg
5760

5861

59-
NoValue = None
60-
61-
6262
###### array creation routines
6363

6464

@@ -499,25 +499,25 @@ def expand_dims(a: ArrayLike, axis):
499499

500500
@normalizer
501501
def flip(m: ArrayLike, axis=None):
502-
result = _flips.flip(m, axis)
502+
result = _impl.flip(m, axis)
503503
return _helpers.array_from(result)
504504

505505

506506
@normalizer
507507
def flipud(m: ArrayLike):
508-
result = _flips.flipud(m)
508+
result = _impl.flipud(m)
509509
return _helpers.array_from(result)
510510

511511

512512
@normalizer
513513
def fliplr(m: ArrayLike):
514-
result = _flips.fliplr(m)
514+
result = _impl.fliplr(m)
515515
return _helpers.array_from(result)
516516

517517

518518
@normalizer
519519
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
520-
result = _flips.rot90(m, k, axes)
520+
result = _impl.rot90(m, k, axes)
521521
return _helpers.array_from(result)
522522

523523

@@ -640,9 +640,7 @@ def average(
640640
*,
641641
keepdims=NoValue,
642642
):
643-
result, wsum = _reductions.average(
644-
a, axis, weights, returned=returned, keepdims=keepdims
645-
)
643+
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
646644
if returned:
647645
return _helpers.tuple_arrays_from((result, wsum))
648646
else:

0 commit comments

Comments
 (0)