Skip to content

Commit 2884032

Browse files
committed
MAINT: make _detail a package, hide individual implementation files and import from _detail
1 parent 9cbbac6 commit 2884032

File tree

6 files changed

+44
-37
lines changed

6 files changed

+44
-37
lines changed

torch_np/_detail/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._flips import *
2+
from ._reductions import *
3+
from .implementations import *
4+
5+
# leading underscore (ndarray.flatten yes, np.flatten no)
6+
from .implementations import _flatten

torch_np/_detail/_reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import _dtypes_impl, _util
1212

13-
NoValue = None
13+
NoValue = _util.NoValue
1414

1515

1616
import functools

torch_np/_detail/_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from . import _dtypes_impl
99

1010

11+
NoValue = None
12+
1113
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
1214
def is_sequence(seq):
1315
if isinstance(seq, str):

torch_np/_funcs.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import torch
44

55
from . import _helpers
6-
from ._detail import _flips, _reductions, _util
7-
from ._detail import implementations as _impl
6+
from . import _detail as _impl
7+
from ._detail import _util
8+
89
from ._normalizations import (
910
ArrayLike,
1011
AxisLike,
@@ -16,6 +17,9 @@
1617
)
1718

1819

20+
NoValue = _util.NoValue
21+
22+
1923
@normalizer
2024
def nonzero(a: ArrayLike):
2125
result = a.nonzero(as_tuple=True)
@@ -159,13 +163,13 @@ def moveaxis(a: ArrayLike, source, destination):
159163

160164
@normalizer
161165
def swapaxes(a: ArrayLike, axis1, axis2):
162-
result = _flips.swapaxes(a, axis1, axis2)
166+
result = _impl.swapaxes(a, axis1, axis2)
163167
return _helpers.array_from(result)
164168

165169

166170
@normalizer
167171
def rollaxis(a: ArrayLike, axis, start=0):
168-
result = _flips.rollaxis(a, axis, start)
172+
result = _impl.rollaxis(a, axis, start)
169173
return _helpers.array_from(result)
170174

171175

@@ -230,10 +234,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
230234

231235
# ### reductions ###
232236

233-
234-
NoValue = None # FIXME
235-
236-
237237
@normalizer
238238
def sum(
239239
a: ArrayLike,
@@ -244,7 +244,7 @@ def sum(
244244
initial=NoValue,
245245
where=NoValue,
246246
):
247-
result = _reductions.sum(
247+
result = _impl.sum(
248248
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
249249
)
250250
return _helpers.result_or_out(result, out)
@@ -260,7 +260,7 @@ def prod(
260260
initial=NoValue,
261261
where=NoValue,
262262
):
263-
result = _reductions.prod(
263+
result = _impl.prod(
264264
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
265265
)
266266
return _helpers.result_or_out(result, out)
@@ -279,7 +279,7 @@ def mean(
279279
*,
280280
where=NoValue,
281281
):
282-
result = _reductions.mean(
282+
result = _impl.mean(
283283
a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims
284284
)
285285
return _helpers.result_or_out(result, out)
@@ -296,7 +296,7 @@ def var(
296296
*,
297297
where=NoValue,
298298
):
299-
result = _reductions.var(
299+
result = _impl.var(
300300
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
301301
)
302302
return _helpers.result_or_out(result, out)
@@ -313,7 +313,7 @@ def std(
313313
*,
314314
where=NoValue,
315315
):
316-
result = _reductions.std(
316+
result = _impl.std(
317317
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
318318
)
319319
return _helpers.result_or_out(result, out)
@@ -327,7 +327,7 @@ def argmin(
327327
*,
328328
keepdims=NoValue,
329329
):
330-
result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
330+
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
331331
return _helpers.result_or_out(result, out)
332332

333333

@@ -339,7 +339,7 @@ def argmax(
339339
*,
340340
keepdims=NoValue,
341341
):
342-
result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
342+
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
343343
return _helpers.result_or_out(result, out)
344344

345345

@@ -352,7 +352,7 @@ def amax(
352352
initial=NoValue,
353353
where=NoValue,
354354
):
355-
result = _reductions.max(
355+
result = _impl.max(
356356
a, axis=axis, initial=initial, where=where, keepdims=keepdims
357357
)
358358
return _helpers.result_or_out(result, out)
@@ -370,7 +370,7 @@ def amin(
370370
initial=NoValue,
371371
where=NoValue,
372372
):
373-
result = _reductions.min(
373+
result = _impl.min(
374374
a, axis=axis, initial=initial, where=where, keepdims=keepdims
375375
)
376376
return _helpers.result_or_out(result, out)
@@ -383,7 +383,7 @@ def amin(
383383
def ptp(
384384
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
385385
):
386-
result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
386+
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
387387
return _helpers.result_or_out(result, out)
388388

389389

@@ -396,7 +396,7 @@ def all(
396396
*,
397397
where=NoValue,
398398
):
399-
result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
399+
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
400400
return _helpers.result_or_out(result, out)
401401

402402

@@ -409,13 +409,13 @@ def any(
409409
*,
410410
where=NoValue,
411411
):
412-
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
412+
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
413413
return _helpers.result_or_out(result, out)
414414

415415

416416
@normalizer
417417
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
418-
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
418+
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
419419
return _helpers.array_from(result)
420420

421421

@@ -426,7 +426,7 @@ def cumsum(
426426
dtype: DTypeLike = None,
427427
out: Optional[NDArray] = None,
428428
):
429-
result = _reductions.cumsum(a, axis=axis, dtype=dtype)
429+
result = _impl.cumsum(a, axis=axis, dtype=dtype)
430430
return _helpers.result_or_out(result, out)
431431

432432

@@ -437,7 +437,7 @@ def cumprod(
437437
dtype: DTypeLike = None,
438438
out: Optional[NDArray] = None,
439439
):
440-
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
440+
result = _impl.cumprod(a, axis=axis, dtype=dtype)
441441
return _helpers.result_or_out(result, out)
442442

443443

@@ -459,5 +459,5 @@ def quantile(
459459
if interpolation is not None:
460460
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
461461

462-
result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims)
462+
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
463463
return _helpers.result_or_out(result, out, promote_scalar=True)

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
6-
from ._detail import _dtypes_impl, _flips, _reductions, _util
6+
from ._detail import _dtypes_impl, _util
77
from ._detail import implementations as _impl
88

99
newaxis = None

torch_np/_wrapper.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
import torch
1010

11-
from . import _funcs, _helpers
12-
from ._detail import _dtypes_impl, _flips, _reductions, _util
13-
from ._detail import implementations as _impl
11+
from . import _detail as _impl
12+
from . import _decorators, _dtypes, _funcs, _helpers
13+
from ._detail import _dtypes_impl, _util
1414
from ._ndarray import asarray
1515
from ._normalizations import (
1616
ArrayLike,
@@ -21,6 +21,8 @@
2121
normalizer,
2222
)
2323

24+
NoValue = _util.NoValue
25+
2426
# Things to decide on (punt for now)
2527
#
2628
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -56,9 +58,6 @@
5658
# - optional out arg
5759

5860

59-
NoValue = None
60-
61-
6261
###### array creation routines
6362

6463

@@ -499,25 +498,25 @@ def expand_dims(a: ArrayLike, axis):
499498

500499
@normalizer
501500
def flip(m: ArrayLike, axis=None):
502-
result = _flips.flip(m, axis)
501+
result = _impl.flip(m, axis)
503502
return _helpers.array_from(result)
504503

505504

506505
@normalizer
507506
def flipud(m: ArrayLike):
508-
result = _flips.flipud(m)
507+
result = _impl.flipud(m)
509508
return _helpers.array_from(result)
510509

511510

512511
@normalizer
513512
def fliplr(m: ArrayLike):
514-
result = _flips.fliplr(m)
513+
result = _impl.fliplr(m)
515514
return _helpers.array_from(result)
516515

517516

518517
@normalizer
519518
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
520-
result = _flips.rot90(m, k, axes)
519+
result = _impl.rot90(m, k, axes)
521520
return _helpers.array_from(result)
522521

523522

@@ -640,7 +639,7 @@ def average(
640639
*,
641640
keepdims=NoValue,
642641
):
643-
result, wsum = _reductions.average(
642+
result, wsum = _impl.average(
644643
a, axis, weights, returned=returned, keepdims=keepdims
645644
)
646645
if returned:

0 commit comments

Comments
 (0)