Skip to content

Commit d923fc8

Browse files
committed
MAINT: make _detail a package, hide individual implementation files and import from _detail
1 parent 980999e commit d923fc8

File tree

8 files changed

+44
-47
lines changed

8 files changed

+44
-47
lines changed

torch_np/_detail/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._flips import *
2+
from ._reductions import *
3+
4+
# leading underscore (ndarray.flatten yes, np.flatten no)
5+
from .implementations import *
6+
from .implementations import _flatten

torch_np/_detail/_reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import _dtypes_impl, _util
1212

13-
NoValue = None
13+
NoValue = _util.NoValue
1414

1515

1616
import functools

torch_np/_detail/_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import _dtypes_impl
99

10+
NoValue = None
1011

1112
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
1213
def is_sequence(seq):

torch_np/_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __new__(self, value):
5151
#
5252
return _ndarray.ndarray(tensor)
5353

54+
5455
##### these are abstract types
5556

5657

torch_np/_funcs.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import torch
44

5+
from . import _detail as _impl
56
from . import _helpers
6-
from ._detail import _flips, _reductions, _util
7-
from ._detail import implementations as _impl
7+
from ._detail import _util
88
from ._normalizations import (
99
ArrayLike,
1010
AxisLike,
@@ -14,6 +14,8 @@
1414
normalizer,
1515
)
1616

17+
NoValue = _util.NoValue
18+
1719

1820
@normalizer
1921
def nonzero(a: ArrayLike):
@@ -158,13 +160,13 @@ def moveaxis(a: ArrayLike, source, destination):
158160

159161
@normalizer
160162
def swapaxes(a: ArrayLike, axis1, axis2):
161-
result = _flips.swapaxes(a, axis1, axis2)
163+
result = _impl.swapaxes(a, axis1, axis2)
162164
return _helpers.array_from(result)
163165

164166

165167
@normalizer
166168
def rollaxis(a: ArrayLike, axis, start=0):
167-
result = _flips.rollaxis(a, axis, start)
169+
result = _impl.rollaxis(a, axis, start)
168170
return _helpers.array_from(result)
169171

170172

@@ -230,9 +232,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
230232
# ### reductions ###
231233

232234

233-
NoValue = None # FIXME
234-
235-
236235
@normalizer
237236
def sum(
238237
a: ArrayLike,
@@ -243,7 +242,7 @@ def sum(
243242
initial=NoValue,
244243
where=NoValue,
245244
):
246-
result = _reductions.sum(
245+
result = _impl.sum(
247246
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
248247
)
249248
return _helpers.result_or_out(result, out)
@@ -259,7 +258,7 @@ def prod(
259258
initial=NoValue,
260259
where=NoValue,
261260
):
262-
result = _reductions.prod(
261+
result = _impl.prod(
263262
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
264263
)
265264
return _helpers.result_or_out(result, out)
@@ -278,9 +277,7 @@ def mean(
278277
*,
279278
where=NoValue,
280279
):
281-
result = _reductions.mean(
282-
a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims
283-
)
280+
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
284281
return _helpers.result_or_out(result, out)
285282

286283

@@ -295,7 +292,7 @@ def var(
295292
*,
296293
where=NoValue,
297294
):
298-
result = _reductions.var(
295+
result = _impl.var(
299296
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
300297
)
301298
return _helpers.result_or_out(result, out)
@@ -312,7 +309,7 @@ def std(
312309
*,
313310
where=NoValue,
314311
):
315-
result = _reductions.std(
312+
result = _impl.std(
316313
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
317314
)
318315
return _helpers.result_or_out(result, out)
@@ -326,7 +323,7 @@ def argmin(
326323
*,
327324
keepdims=NoValue,
328325
):
329-
result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
326+
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
330327
return _helpers.result_or_out(result, out)
331328

332329

@@ -338,7 +335,7 @@ def argmax(
338335
*,
339336
keepdims=NoValue,
340337
):
341-
result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
338+
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
342339
return _helpers.result_or_out(result, out)
343340

344341

@@ -351,9 +348,7 @@ def amax(
351348
initial=NoValue,
352349
where=NoValue,
353350
):
354-
result = _reductions.max(
355-
a, axis=axis, initial=initial, where=where, keepdims=keepdims
356-
)
351+
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
357352
return _helpers.result_or_out(result, out)
358353

359354

@@ -369,9 +364,7 @@ def amin(
369364
initial=NoValue,
370365
where=NoValue,
371366
):
372-
result = _reductions.min(
373-
a, axis=axis, initial=initial, where=where, keepdims=keepdims
374-
)
367+
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
375368
return _helpers.result_or_out(result, out)
376369

377370

@@ -382,7 +375,7 @@ def amin(
382375
def ptp(
383376
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
384377
):
385-
result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
378+
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
386379
return _helpers.result_or_out(result, out)
387380

388381

@@ -395,7 +388,7 @@ def all(
395388
*,
396389
where=NoValue,
397390
):
398-
result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
391+
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
399392
return _helpers.result_or_out(result, out)
400393

401394

@@ -408,13 +401,13 @@ def any(
408401
*,
409402
where=NoValue,
410403
):
411-
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
404+
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
412405
return _helpers.result_or_out(result, out)
413406

414407

415408
@normalizer
416409
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
417-
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
410+
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
418411
return _helpers.array_from(result)
419412

420413

@@ -425,7 +418,7 @@ def cumsum(
425418
dtype: DTypeLike = None,
426419
out: Optional[NDArray] = None,
427420
):
428-
result = _reductions.cumsum(a, axis=axis, dtype=dtype)
421+
result = _impl.cumsum(a, axis=axis, dtype=dtype)
429422
return _helpers.result_or_out(result, out)
430423

431424

@@ -436,7 +429,7 @@ def cumprod(
436429
dtype: DTypeLike = None,
437430
out: Optional[NDArray] = None,
438431
):
439-
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
432+
result = _impl.cumprod(a, axis=axis, dtype=dtype)
440433
return _helpers.result_or_out(result, out)
441434

442435

@@ -458,5 +451,5 @@ def quantile(
458451
if interpolation is not None:
459452
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
460453

461-
result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims)
454+
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
462455
return _helpers.result_or_out(result, out, promote_scalar=True)

torch_np/_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
8686

8787
def array_from(tensor, base=None):
8888
from ._ndarray import ndarray
89+
8990
return ndarray(tensor)
9091

9192

torch_np/_ndarray.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
6-
from ._detail import _dtypes_impl, _flips, _reductions, _util
6+
from ._detail import _dtypes_impl, _util
77
from ._detail import implementations as _impl
88

99
newaxis = None
@@ -93,7 +93,6 @@ def strides(self):
9393
def itemsize(self):
9494
return self.tensor.element_size()
9595

96-
9796
@property
9897
def flags(self):
9998
# Note contiguous in torch is assumed C-style
@@ -424,7 +423,6 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
424423
if isinstance(obj, ndarray):
425424
obj = obj.tensor
426425

427-
428426
# is a specific dtype requrested?
429427
torch_dtype = None
430428
if dtype is not None:
@@ -434,7 +432,6 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
434432
return ndarray(tensor)
435433

436434

437-
438435
def asarray(a, dtype=None, order=None, *, like=None):
439436
if order is None:
440437
order = "K"

torch_np/_wrapper.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88

99
import torch
1010

11-
from . import _funcs, _helpers
12-
from ._detail import _dtypes_impl, _flips, _reductions, _util
13-
from ._detail import implementations as _impl
11+
from . import _decorators
12+
from . import _detail as _impl
13+
from . import _dtypes, _funcs, _helpers
14+
from ._detail import _dtypes_impl, _util
1415
from ._ndarray import asarray
1516
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
1617

18+
NoValue = _util.NoValue
19+
1720
# Things to decide on (punt for now)
1821
#
1922
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -49,9 +52,6 @@
4952
# - optional out arg
5053

5154

52-
NoValue = None
53-
54-
5555
###### array creation routines
5656

5757

@@ -492,25 +492,25 @@ def expand_dims(a: ArrayLike, axis):
492492

493493
@normalizer
494494
def flip(m: ArrayLike, axis=None):
495-
result = _flips.flip(m, axis)
495+
result = _impl.flip(m, axis)
496496
return _helpers.array_from(result)
497497

498498

499499
@normalizer
500500
def flipud(m: ArrayLike):
501-
result = _flips.flipud(m)
501+
result = _impl.flipud(m)
502502
return _helpers.array_from(result)
503503

504504

505505
@normalizer
506506
def fliplr(m: ArrayLike):
507-
result = _flips.fliplr(m)
507+
result = _impl.fliplr(m)
508508
return _helpers.array_from(result)
509509

510510

511511
@normalizer
512512
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
513-
result = _flips.rot90(m, k, axes)
513+
result = _impl.rot90(m, k, axes)
514514
return _helpers.array_from(result)
515515

516516

@@ -631,9 +631,7 @@ def average(
631631
*,
632632
keepdims=NoValue,
633633
):
634-
result, wsum = _reductions.average(
635-
a, axis, weights, returned=returned, keepdims=keepdims
636-
)
634+
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
637635
if returned:
638636
return _helpers.tuple_arrays_from((result, wsum))
639637
else:

0 commit comments

Comments
 (0)