Skip to content

MAINT: make _detail a package, hide individual implementation files #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torch_np/_detail/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._flips import *
from ._reductions import *

# leading underscore (ndarray.flatten yes, np.flatten no)
from .implementations import *
from .implementations import _flatten
2 changes: 1 addition & 1 deletion torch_np/_detail/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from . import _dtypes_impl, _util

NoValue = None
NoValue = _util.NoValue


import functools
Expand Down
1 change: 1 addition & 0 deletions torch_np/_detail/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import _dtypes_impl

NoValue = None

# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
Expand Down
51 changes: 22 additions & 29 deletions torch_np/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch

from . import _detail as _impl
from . import _helpers
from ._detail import _flips, _reductions, _util
from ._detail import implementations as _impl
from ._detail import _util
from ._normalizations import (
ArrayLike,
AxisLike,
Expand All @@ -14,6 +14,8 @@
normalizer,
)

NoValue = _util.NoValue


@normalizer
def nonzero(a: ArrayLike):
Expand Down Expand Up @@ -158,13 +160,13 @@ def moveaxis(a: ArrayLike, source, destination):

@normalizer
def swapaxes(a: ArrayLike, axis1, axis2):
result = _flips.swapaxes(a, axis1, axis2)
result = _impl.swapaxes(a, axis1, axis2)
return _helpers.array_from(result)


@normalizer
def rollaxis(a: ArrayLike, axis, start=0):
result = _flips.rollaxis(a, axis, start)
result = _impl.rollaxis(a, axis, start)
return _helpers.array_from(result)


Expand Down Expand Up @@ -230,9 +232,6 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
# ### reductions ###


NoValue = None # FIXME


@normalizer
def sum(
a: ArrayLike,
Expand All @@ -243,7 +242,7 @@ def sum(
initial=NoValue,
where=NoValue,
):
result = _reductions.sum(
result = _impl.sum(
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
Expand All @@ -259,7 +258,7 @@ def prod(
initial=NoValue,
where=NoValue,
):
result = _reductions.prod(
result = _impl.prod(
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
Expand All @@ -278,9 +277,7 @@ def mean(
*,
where=NoValue,
):
result = _reductions.mean(
a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims
)
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -295,7 +292,7 @@ def var(
*,
where=NoValue,
):
result = _reductions.var(
result = _impl.var(
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
Expand All @@ -312,7 +309,7 @@ def std(
*,
where=NoValue,
):
result = _reductions.std(
result = _impl.std(
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
)
return _helpers.result_or_out(result, out)
Expand All @@ -326,7 +323,7 @@ def argmin(
*,
keepdims=NoValue,
):
result = _reductions.argmin(a, axis=axis, keepdims=keepdims)
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -338,7 +335,7 @@ def argmax(
*,
keepdims=NoValue,
):
result = _reductions.argmax(a, axis=axis, keepdims=keepdims)
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -351,9 +348,7 @@ def amax(
initial=NoValue,
where=NoValue,
):
result = _reductions.max(
a, axis=axis, initial=initial, where=where, keepdims=keepdims
)
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -369,9 +364,7 @@ def amin(
initial=NoValue,
where=NoValue,
):
result = _reductions.min(
a, axis=axis, initial=initial, where=where, keepdims=keepdims
)
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -382,7 +375,7 @@ def amin(
def ptp(
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
):
result = _reductions.ptp(a, axis=axis, keepdims=keepdims)
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -395,7 +388,7 @@ def all(
*,
where=NoValue,
):
result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims)
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)


Expand All @@ -408,13 +401,13 @@ def any(
*,
where=NoValue,
):
result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims)
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
return _helpers.result_or_out(result, out)


@normalizer
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False):
result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims)
result = _impl.count_nonzero(a, axis=axis, keepdims=keepdims)
return _helpers.array_from(result)


Expand All @@ -425,7 +418,7 @@ def cumsum(
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
):
result = _reductions.cumsum(a, axis=axis, dtype=dtype)
result = _impl.cumsum(a, axis=axis, dtype=dtype)
return _helpers.result_or_out(result, out)


Expand All @@ -436,7 +429,7 @@ def cumprod(
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
):
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
result = _impl.cumprod(a, axis=axis, dtype=dtype)
return _helpers.result_or_out(result, out)


Expand All @@ -458,5 +451,5 @@ def quantile(
if interpolation is not None:
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")

result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims)
result = _impl.quantile(a, q, axis, method=method, keepdims=keepdims)
return _helpers.result_or_out(result, out, promote_scalar=True)
2 changes: 1 addition & 1 deletion torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
from ._detail import _dtypes_impl, _flips, _reductions, _util
from ._detail import _dtypes_impl, _util
from ._detail import implementations as _impl

newaxis = None
Expand Down
24 changes: 11 additions & 13 deletions torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

import torch

from . import _funcs, _helpers
from ._detail import _dtypes_impl, _flips, _reductions, _util
from ._detail import implementations as _impl
from . import _decorators
from . import _detail as _impl
from . import _dtypes, _funcs, _helpers
from ._detail import _dtypes_impl, _util
from ._ndarray import asarray
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer

NoValue = _util.NoValue

# Things to decide on (punt for now)
#
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
Expand Down Expand Up @@ -49,9 +52,6 @@
# - optional out arg


NoValue = None


###### array creation routines


Expand Down Expand Up @@ -492,25 +492,25 @@ def expand_dims(a: ArrayLike, axis):

@normalizer
def flip(m: ArrayLike, axis=None):
result = _flips.flip(m, axis)
result = _impl.flip(m, axis)
return _helpers.array_from(result)


@normalizer
def flipud(m: ArrayLike):
result = _flips.flipud(m)
result = _impl.flipud(m)
return _helpers.array_from(result)


@normalizer
def fliplr(m: ArrayLike):
result = _flips.fliplr(m)
result = _impl.fliplr(m)
return _helpers.array_from(result)


@normalizer
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
result = _flips.rot90(m, k, axes)
result = _impl.rot90(m, k, axes)
return _helpers.array_from(result)


Expand Down Expand Up @@ -631,9 +631,7 @@ def average(
*,
keepdims=NoValue,
):
result, wsum = _reductions.average(
a, axis, weights, returned=returned, keepdims=keepdims
)
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
if returned:
return _helpers.tuple_arrays_from((result, wsum))
else:
Expand Down