diff --git a/torch_np/__init__.py b/torch_np/__init__.py index f48fcdfa..57cb6cd9 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -1,4 +1,3 @@ -from ._wrapper import * # isort: skip # XXX: currently this prevents circular imports from . import random from ._binary_ufuncs import * from ._detail._index_tricks import * @@ -8,6 +7,7 @@ from ._getlimits import errstate, finfo, iinfo from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type from ._unary_ufuncs import * +from ._wrapper import * # from . import testing diff --git a/torch_np/_binary_ufuncs.py b/torch_np/_binary_ufuncs.py index 719771c5..9f2ca4a0 100644 --- a/torch_np/_binary_ufuncs.py +++ b/torch_np/_binary_ufuncs.py @@ -1,52 +1,52 @@ -from ._decorators import deco_binary_ufunc_from_impl -from ._detail import _ufunc_impl +from typing import Optional + +from . import _helpers +from ._detail import _binary_ufuncs +from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer + +__all__ = [ + name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch" +] + + +def deco_binary_ufunc(torch_func): + """Common infra for binary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. + """ + + def wrapped( + x1: ArrayLike, + x2: ArrayLike, + /, + out: Optional[NDArray] = None, + *, + where=True, + casting="same_kind", + order="K", + dtype: DTypeLike = None, + subok: SubokLike = False, + signature=None, + extobj=None, + ): + tensors = _helpers.ufunc_preprocess( + (x1, x2), out, where, casting, order, dtype, subok, signature, extobj + ) + result = torch_func(*tensors) + return _helpers.result_or_out(result, out) + + return wrapped + # -# Functions in this file implement binary ufuncs: wrap two first arguments in -# asarray and delegate to functions from _ufunc_impl. -# -# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks -# with ufunc args, and delegate heavy lifting to pytorch equivalents. +# For each torch ufunc implementation, decorate and attach the decorated name +# to this module. 
Its contents is then exported to the public namespace in __init__.py # +for name in __all__: + ufunc = getattr(_binary_ufuncs, name) + decorated = normalizer(deco_binary_ufunc(ufunc)) -# the list is autogenerated, cf autogen/gen_ufunc_2.py -add = deco_binary_ufunc_from_impl(_ufunc_impl.add) -arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2) -bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and) -bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or) -bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor) -copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign) -divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide) -equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal) -float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power) -floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide) -fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax) -fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin) -fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod) -gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd) -greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater) -greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal) -heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside) -hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot) -lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm) -ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp) -left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift) -less = deco_binary_ufunc_from_impl(_ufunc_impl.less) -less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal) -logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp) -logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2) -logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and) -logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or) -logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor) -matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul) -maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum) -minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum) -remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder) -multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply) -nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter) -not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal) -power = deco_binary_ufunc_from_impl(_ufunc_impl.power) -remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder) -right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift) -subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract) -divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide) + decorated.__qualname__ = name # XXX: is this really correct? + decorated.__name__ = name + vars()[name] = decorated diff --git a/torch_np/_decorators.py b/torch_np/_decorators.py index 37c98bf6..c8542e1b 100644 --- a/torch_np/_decorators.py +++ b/torch_np/_decorators.py @@ -1,39 +1,10 @@ import functools -import operator import torch from . import _dtypes, _helpers from ._detail import _util -NoValue = None - - -def dtype_to_torch(func): - @functools.wraps(func) - def wrapped(*args, dtype=None, **kwds): - torch_dtype = None - if dtype is not None: - dtype = _dtypes.dtype(dtype) - torch_dtype = dtype.torch_dtype - return func(*args, dtype=torch_dtype, **kwds) - - return wrapped - - -def emulate_out_arg(func): - """Simulate the out=... handling: move the result tensor to the out array. 
- - With this decorator, the inner function just does not see the out array. - """ - - @functools.wraps(func) - def wrapped(*args, out=None, **kwds): - result_tensor = func(*args, **kwds) - return _helpers.result_or_out(result_tensor, out) - - return wrapped - def out_shape_dtype(func): """Handle out=... kwarg for ufuncs. @@ -51,89 +22,3 @@ def wrapped(*args, out=None, **kwds): return _helpers.result_or_out(result_tensor, out) return wrapped - - -def deco_unary_ufunc_from_impl(impl_func): - @functools.wraps(impl_func) - @dtype_to_torch - @out_shape_dtype - def wrapped(x1, *args, **kwds): - from ._ndarray import asarray - - x1_tensor = asarray(x1).get() - result = impl_func((x1_tensor,), *args, **kwds) - return result - - return wrapped - - -# TODO: deduplicate with _ndarray/asarray_replacer, -# and _wrapper/concatenate et al -def deco_binary_ufunc_from_impl(impl_func): - @functools.wraps(impl_func) - @dtype_to_torch - @out_shape_dtype - def wrapped(x1, x2, *args, **kwds): - from ._ndarray import asarray - - x1_tensor = asarray(x1).get() - x2_tensor = asarray(x2).get() - return impl_func((x1_tensor, x2_tensor), *args, **kwds) - - return wrapped - - -def axis_keepdims_wrapper(func): - """`func` accepts an array-like as a 1st arg, returns a tensor. - - This decorator implements the generic handling of axis, out and keepdims - arguments for reduction functions. - - Note that we peel off `out=...` and `keepdims=...` args (torch functions never - see them). The `axis` argument we normalize and pass through to pytorch functions. - - """ - # TODO: sort out function signatures: how they flow through all decorators etc - @functools.wraps(func) - def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds): - from ._ndarray import asarray, ndarray - - tensor = asarray(a).get() - - # standardize the axis argument - if isinstance(axis, ndarray): - axis = operator.index(axis) - - result = _util.axis_expand_func(func, tensor, axis, *args, **kwds) - - if keepdims: - result = _util.apply_keepdims(result, axis, tensor.ndim) - - return result - - return wrapped - - -def axis_none_ravel_wrapper(func): - """`func` accepts an array-like as a 1st arg, returns a tensor. - - This decorator implements the generic handling of axis=None acting on a - raveled array. One use is cumprod / cumsum. concatenate also uses a - similar logic. - - """ - - @functools.wraps(func) - def wrapped(a, axis=None, *args, **kwds): - from ._ndarray import asarray, ndarray - - tensor = asarray(a).get() - - # standardize the axis argument - if isinstance(axis, ndarray): - axis = operator.index(axis) - - result = _util.axis_ravel_func(func, tensor, axis, *args, **kwds) - return result - - return wrapped diff --git a/torch_np/_detail/_binary_ufuncs.py b/torch_np/_detail/_binary_ufuncs.py new file mode 100644 index 00000000..fbdff059 --- /dev/null +++ b/torch_np/_detail/_binary_ufuncs.py @@ -0,0 +1,53 @@ +"""Export torch work functions for binary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `torch_np/_binary_ufuncs.py` module. 
+""" + +import torch + +# renames +from torch import add, arctan2, bitwise_and +from torch import bitwise_left_shift as left_shift +from torch import bitwise_or +from torch import bitwise_right_shift as right_shift +from torch import bitwise_xor, copysign, divide +from torch import eq as equal +from torch import ( + float_power, + floor_divide, + fmax, + fmin, + fmod, + gcd, + greater, + greater_equal, + heaviside, + hypot, + lcm, + ldexp, + less, + less_equal, + logaddexp, + logaddexp2, + logical_and, + logical_or, + logical_xor, + maximum, + minimum, + multiply, + nextafter, + not_equal, +) +from torch import pow as power +from torch import remainder, subtract + +from . import _dtypes_impl, _util + + +# work around torch limitations w.r.t. numpy +def matmul(x, y): + # work around RuntimeError: expected scalar type Int but found Double + dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) + x = _util.cast_if_needed(x, dtype) + y = _util.cast_if_needed(y, dtype) + result = torch.matmul(x, y) + return result diff --git a/torch_np/_detail/_reductions.py b/torch_np/_detail/_reductions.py index a3d0a0b7..549f20e0 100644 --- a/torch_np/_detail/_reductions.py +++ b/torch_np/_detail/_reductions.py @@ -4,6 +4,8 @@ Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc """ +import typing + import torch from . import _dtypes_impl, _util @@ -11,6 +13,73 @@ NoValue = None +import functools + +############# XXX +### From _util.axis_expand_func + + +def deco_axis_expand(func): + """Generically handle axis arguments in reductions.""" + + @functools.wraps(func) + def wrapped(tensor, axis, *args, **kwds): + + if axis is not None: + if not isinstance(axis, (list, tuple)): + if not isinstance(axis, typing.SupportsIndex): + raise TypeError( + f"{type(axis)=}, but should be a list/tuple or support operator.index()" + ) + axis = (axis,) + axis = _util.normalize_axis_tuple(axis, tensor.ndim) + + if axis == (): + # NumPy does essentially an identity operation: + # >>> np.sum(np.ones(2), axis=()) + # array([1., 1.]) + # So we insert a length-one axis and run the reduction along it. + newshape = _util.expand_shape(tensor.shape, axis=0) + tensor = tensor.reshape(newshape) + axis = (0,) + + result = func(tensor, axis=axis, *args, **kwds) + return result + + return wrapped + + +def emulate_keepdims(func): + @functools.wraps(func) + def wrapped(tensor, axis=None, keepdims=NoValue, *args, **kwds): + result = func(tensor, axis=axis, *args, **kwds) + if keepdims: + result = _util.apply_keepdims(result, axis, tensor.ndim) + return result + + return wrapped + + +def deco_axis_ravel(func): + """Generically handle 'axis=None ravels' behavior.""" + + @functools.wraps(func) + def wrapped(tensor, axis, *args, **kwds): + if axis is not None: + axis = _util.normalize_axis_index(axis, tensor.ndim) + + tensors, axis = _util.axis_none_ravel(tensor, axis=axis) # XXX: inline + tensor = tensors[0] + + result = func(tensor, axis=axis, *args, **kwds) + return result + + return wrapped + + +##################################3 + + def _atleast_float(dtype, other_dtype): """Return a dtype that is real or complex floating-point. 
@@ -25,6 +94,8 @@ def _atleast_float(dtype, other_dtype): return dtype +@emulate_keepdims +@deco_axis_expand def count_nonzero(a, axis=None): # XXX: this all should probably be generalized to a sum(a != 0, dtype=bool) try: @@ -34,6 +105,8 @@ def count_nonzero(a, axis=None): return tensor +@emulate_keepdims +@deco_axis_expand def argmax(tensor, axis=None): axis = _util.allow_only_single_axis(axis) @@ -45,6 +118,8 @@ def argmax(tensor, axis=None): return tensor +@emulate_keepdims +@deco_axis_expand def argmin(tensor, axis=None): axis = _util.allow_only_single_axis(axis) @@ -56,6 +131,8 @@ def argmin(tensor, axis=None): return tensor +@emulate_keepdims +@deco_axis_expand def any(tensor, axis=None, *, where=NoValue): if where is not NoValue: raise NotImplementedError @@ -69,6 +146,8 @@ def any(tensor, axis=None, *, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def all(tensor, axis=None, *, where=NoValue): if where is not NoValue: raise NotImplementedError @@ -82,6 +161,8 @@ def all(tensor, axis=None, *, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def max(tensor, axis=None, initial=NoValue, where=NoValue): if initial is not NoValue or where is not NoValue: raise NotImplementedError @@ -90,6 +171,8 @@ def max(tensor, axis=None, initial=NoValue, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def min(tensor, axis=None, initial=NoValue, where=NoValue): if initial is not NoValue or where is not NoValue: raise NotImplementedError @@ -98,11 +181,15 @@ def min(tensor, axis=None, initial=NoValue, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def ptp(tensor, axis=None): result = tensor.amax(axis) - tensor.amin(axis) return result +@emulate_keepdims +@deco_axis_expand def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): if initial is not NoValue or where is not NoValue: raise NotImplementedError @@ -120,6 +207,8 @@ def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): if initial is not NoValue or where is not NoValue: raise NotImplementedError @@ -137,6 +226,8 @@ def prod(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def mean(tensor, axis=None, dtype=None, *, where=NoValue): if where is not NoValue: raise NotImplementedError @@ -159,6 +250,8 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): if where is not NoValue: raise NotImplementedError @@ -170,6 +263,8 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): return result +@emulate_keepdims +@deco_axis_expand def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): if where is not NoValue: raise NotImplementedError @@ -186,6 +281,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): # 2. 
axis=None ravels (cf concatenate) +@deco_axis_ravel def cumprod(tensor, axis, dtype=None): if dtype == torch.bool: dtype = _dtypes_impl.default_int_dtype @@ -197,6 +293,7 @@ def cumprod(tensor, axis, dtype=None): return result +@deco_axis_ravel def cumsum(tensor, axis, dtype=None): if dtype == torch.bool: dtype = _dtypes_impl.default_int_dtype @@ -208,7 +305,25 @@ def cumsum(tensor, axis, dtype=None): return result -def average(a_tensor, axis, w_tensor): +def average(a, axis, weights, returned=False, keepdims=False): + if weights is None: + result, wsum = average_noweights(a, axis, keepdims=keepdims) + else: + result, wsum = average_weights(a, axis, weights, keepdims=keepdims) + + if returned: + if wsum.shape != result.shape: + wsum = torch.broadcast_to(wsum, result.shape).clone() + return result, wsum + + +def average_noweights(a_tensor, axis, keepdims=False): + result = mean(a_tensor, axis=axis, keepdims=keepdims) + scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype) + return result, scl + + +def average_weights(a_tensor, axis, w_tensor, keepdims=False): # dtype # FIXME: 1. use result_type @@ -222,6 +337,9 @@ def average(a_tensor, axis, w_tensor): a_tensor = _util.cast_if_needed(a_tensor, result_dtype) w_tensor = _util.cast_if_needed(w_tensor, result_dtype) + # axis=None ravels, so store the originals to reuse with keepdims=True below + ax, ndim = axis, a_tensor.ndim + # axis if axis is None: (a_tensor, w_tensor), axis = _util.axis_none_ravel( @@ -250,10 +368,14 @@ def average(a_tensor, axis, w_tensor): denominator = w_tensor.sum(axis) result = numerator / denominator + # keepdims + if keepdims: + result = _util.apply_keepdims(result, ax, ndim) + return result, denominator -def quantile(a_tensor, q_tensor, axis, method): +def quantile(a_tensor, q_tensor, axis, method, keepdims=False): if (0 > q_tensor).any() or (q_tensor > 1).any(): raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor) @@ -266,6 +388,7 @@ def quantile(a_tensor, q_tensor, axis, method): if a_tensor.dtype == torch.float16: a_tensor = a_tensor.to(torch.float32) + # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all. # axis if axis is not None: axis = _util.normalize_axis_tuple(axis, a_tensor.ndim) @@ -273,8 +396,15 @@ def quantile(a_tensor, q_tensor, axis, method): q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype) + # axis=None ravels, so store the originals to reuse with keepdims=True below + ax, ndim = axis, a_tensor.ndim (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis) result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method) + # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...) + # while the decorator expects (a, axis, ...) + # this can be fixed, of course, but the cure seems worse then the desease + if keepdims: + result = _util.apply_keepdims(result, ax, ndim) return result diff --git a/torch_np/_detail/_ufunc_impl.py b/torch_np/_detail/_ufunc_impl.py deleted file mode 100644 index 2cd1ecee..00000000 --- a/torch_np/_detail/_ufunc_impl.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch - -from . import _dtypes_impl, _util - - -def deco_ufunc(torch_func): - """Common infra for binary ufuncs: receive tensors, sort out type casting, - broadcasting, and delegate to the pytorch function for actual work. - - - Converting array-likes into arrays, unwrapping them into tensors etc - is the caller responsibility. 
- """ - - def wrapped( - tensors, - /, - out_shape_dtype=None, - *, - where=True, - casting="same_kind", - order="K", - dtype=None, - subok=False, - **kwds, - ): - _util.subok_not_ok(subok=subok) - if order != "K" or not where: - raise NotImplementedError - - # XXX: dtype=... parameter - if dtype is not None: - raise NotImplementedError - - tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting) - - result = torch_func(*tensors) - return result - - return wrapped - - -# binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py -# And edited manually! np.equal <--> torch.eq, not torch.equal -add = deco_ufunc(torch.add) -arctan2 = deco_ufunc(torch.arctan2) -bitwise_and = deco_ufunc(torch.bitwise_and) -bitwise_or = deco_ufunc(torch.bitwise_or) -bitwise_xor = deco_ufunc(torch.bitwise_xor) -copysign = deco_ufunc(torch.copysign) -divide = deco_ufunc(torch.divide) -equal = deco_ufunc(torch.eq) -float_power = deco_ufunc(torch.float_power) -floor_divide = deco_ufunc(torch.floor_divide) -fmax = deco_ufunc(torch.fmax) -fmin = deco_ufunc(torch.fmin) -fmod = deco_ufunc(torch.fmod) -gcd = deco_ufunc(torch.gcd) -greater = deco_ufunc(torch.greater) -greater_equal = deco_ufunc(torch.greater_equal) -heaviside = deco_ufunc(torch.heaviside) -hypot = deco_ufunc(torch.hypot) -lcm = deco_ufunc(torch.lcm) -ldexp = deco_ufunc(torch.ldexp) -left_shift = deco_ufunc(torch.bitwise_left_shift) -less = deco_ufunc(torch.less) -less_equal = deco_ufunc(torch.less_equal) -logaddexp = deco_ufunc(torch.logaddexp) -logaddexp2 = deco_ufunc(torch.logaddexp2) -logical_and = deco_ufunc(torch.logical_and) -logical_or = deco_ufunc(torch.logical_or) -logical_xor = deco_ufunc(torch.logical_xor) -maximum = deco_ufunc(torch.maximum) -minimum = deco_ufunc(torch.minimum) -remainder = deco_ufunc(torch.remainder) -multiply = deco_ufunc(torch.multiply) -nextafter = deco_ufunc(torch.nextafter) -not_equal = deco_ufunc(torch.not_equal) -power = deco_ufunc(torch.pow) -remainder = deco_ufunc(torch.remainder) -right_shift = deco_ufunc(torch.bitwise_right_shift) -subtract = deco_ufunc(torch.subtract) -divide = deco_ufunc(torch.divide) - - -# unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py -arccos = deco_ufunc(torch.arccos) -arccosh = deco_ufunc(torch.arccosh) -arcsin = deco_ufunc(torch.arcsin) -arcsinh = deco_ufunc(torch.arcsinh) -arctan = deco_ufunc(torch.arctan) -arctanh = deco_ufunc(torch.arctanh) -ceil = deco_ufunc(torch.ceil) -conjugate = deco_ufunc(torch.conj_physical) -# conjugate = deco_ufunc(torch.conj_physical) -cos = deco_ufunc(torch.cos) -cosh = deco_ufunc(torch.cosh) -deg2rad = deco_ufunc(torch.deg2rad) -degrees = deco_ufunc(torch.rad2deg) -exp = deco_ufunc(torch.exp) -exp2 = deco_ufunc(torch.exp2) -expm1 = deco_ufunc(torch.expm1) -fabs = deco_ufunc(torch.absolute) -floor = deco_ufunc(torch.floor) -isfinite = deco_ufunc(torch.isfinite) -isinf = deco_ufunc(torch.isinf) -isnan = deco_ufunc(torch.isnan) -log = deco_ufunc(torch.log) -log10 = deco_ufunc(torch.log10) -log1p = deco_ufunc(torch.log1p) -log2 = deco_ufunc(torch.log2) -logical_not = deco_ufunc(torch.logical_not) -negative = deco_ufunc(torch.negative) -rad2deg = deco_ufunc(torch.rad2deg) -radians = deco_ufunc(torch.deg2rad) -reciprocal = deco_ufunc(torch.reciprocal) -rint = deco_ufunc(torch.round) -sign = deco_ufunc(torch.sign) -signbit = deco_ufunc(torch.signbit) -sin = deco_ufunc(torch.sin) -sinh = deco_ufunc(torch.sinh) -sqrt = deco_ufunc(torch.sqrt) -square = deco_ufunc(torch.square) -tan = deco_ufunc(torch.tan) -tanh = 
deco_ufunc(torch.tanh) -trunc = deco_ufunc(torch.trunc) - -invert = deco_ufunc(torch.bitwise_not) - -# special cases: torch does not export these names -def _cbrt(x): - return torch.pow(x, 1 / 3) - - -def _positive(x): - return +x - - -def _absolute(x): - # work around torch.absolute not impl for bools - if x.dtype == torch.bool: - return x - return torch.absolute(x) - - -def _matmul(x, y): - # work around RuntimeError: expected scalar type Int but found Double - dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype)) - x = _util.cast_if_needed(x, dtype) - y = _util.cast_if_needed(y, dtype) - result = torch.matmul(x, y) - return result - - -cbrt = deco_ufunc(_cbrt) -positive = deco_ufunc(_positive) -absolute = deco_ufunc(_absolute) -matmul = deco_ufunc(_matmul) diff --git a/torch_np/_detail/_unary_ufuncs.py b/torch_np/_detail/_unary_ufuncs.py new file mode 100644 index 00000000..e482e85f --- /dev/null +++ b/torch_np/_detail/_unary_ufuncs.py @@ -0,0 +1,55 @@ +"""Export torch work functions for unary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `torch_np/_unary_ufuncs.py` module. +""" + +import torch + +# renames +from torch import absolute as fabs +from torch import arccos, arccosh, arcsin, arcsinh, arctan, arctanh +from torch import bitwise_not as invert +from torch import ceil +from torch import conj_physical as conjugate +from torch import cos, cosh +from torch import deg2rad +from torch import deg2rad as radians +from torch import ( + exp, + exp2, + expm1, + floor, + isfinite, + isinf, + isnan, + log, + log1p, + log2, + log10, + logical_not, + negative, +) +from torch import rad2deg +from torch import rad2deg as degrees +from torch import reciprocal +from torch import round as rint +from torch import sign, signbit, sin, sinh, sqrt, square, tan, tanh, trunc + + +# special cases: torch does not export these names +def cbrt(x): + return torch.pow(x, 1 / 3) + + +def positive(x): + return +x + + +def absolute(x): + # work around torch.absolute not impl for bools + if x.dtype == torch.bool: + return x + return torch.absolute(x) + + +abs = absolute +conj = conjugate diff --git a/torch_np/_detail/_util.py b/torch_np/_detail/_util.py index 550f5492..47d4ae40 100644 --- a/torch_np/_detail/_util.py +++ b/torch_np/_detail/_util.py @@ -128,7 +128,7 @@ def apply_keepdims(tensor, axis, ndim): def axis_none_ravel(*tensors, axis=None): """Ravel the arrays if axis is none.""" - # XXX: is only used at `concatenate`. Inline unless reused more widely + # XXX: is only used at `concatenate` and cumsum/cumprod. Inline unless reused more widely if axis is None: tensors = tuple(ar.ravel() for ar in tensors) return tensors, 0 diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 315b28b0..627e93dd 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -261,12 +261,6 @@ def dsplit(tensor, indices_or_sections): def clip(tensor, t_min, t_max): - if t_min is not None: - t_min = torch.broadcast_to(t_min, tensor.shape) - - if t_max is not None: - t_max = torch.broadcast_to(t_max, tensor.shape) - if t_min is None and t_max is None: raise ValueError("One of max or min must be given") diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 9ebcd364..59dd594f 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -1,13 +1,23 @@ +from typing import Optional + import torch -from . import _decorators, _helpers -from ._detail import _dtypes_impl, _flips, _util +from . 
import _helpers +from ._detail import _flips, _reductions, _util from ._detail import implementations as _impl - - -def nonzero(a): - (tensor,) = _helpers.to_tensors(a) - result = tensor.nonzero(as_tuple=True) +from ._normalizations import ( + ArrayLike, + AxisLike, + DTypeLike, + NDArray, + SubokLike, + normalizer, +) + + +@normalizer +def nonzero(a: ArrayLike): + result = a.nonzero(as_tuple=True) return _helpers.tuple_arrays_from(result) @@ -17,62 +27,71 @@ def argwhere(a): return _helpers.array_from(result) -def clip(a, min=None, max=None, out=None): +@normalizer +def clip( + a: ArrayLike, + min: Optional[ArrayLike] = None, + max: Optional[ArrayLike] = None, + out: Optional[NDArray] = None, +): # np.clip requires both a_min and a_max not None, while ndarray.clip allows # one of them to be None. Follow the more lax version. - # Also min/max as arg names: follow numpy naming. - tensor, t_min, t_max = _helpers.to_tensors_or_none(a, min, max) - result = _impl.clip(tensor, t_min, t_max) + result = _impl.clip(a, min, max) return _helpers.result_or_out(result, out) -def repeat(a, repeats, axis=None): - tensor, t_repeats = _helpers.to_tensors(a, repeats) # XXX: scalar repeats - result = torch.repeat_interleave(tensor, t_repeats, axis) +@normalizer +def repeat(a: ArrayLike, repeats: ArrayLike, axis=None): + # XXX: scalar repeats; ArrayLikeOrScalar ? + result = torch.repeat_interleave(a, repeats, axis) return _helpers.array_from(result) # ### diag et al ### -def diagonal(a, offset=0, axis1=0, axis2=1): - (tensor,) = _helpers.to_tensors(a) - result = _impl.diagonal(tensor, offset, axis1, axis2) +@normalizer +def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): + result = _impl.diagonal(a, offset, axis1, axis2) return _helpers.array_from(result) -@_decorators.dtype_to_torch -def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.trace(tensor, offset, axis1, axis2, dtype) +@normalizer +def trace( + a: ArrayLike, + offset=0, + axis1=0, + axis2=1, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, +): + result = _impl.trace(a, offset, axis1, axis2, dtype) return _helpers.result_or_out(result, out) -@_decorators.dtype_to_torch -def eye(N, M=None, k=0, dtype=float, order="C", *, like=None): - _util.subok_not_ok(like) +@normalizer +def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None): if order != "C": raise NotImplementedError result = _impl.eye(N, M, k, dtype) return _helpers.array_from(result) -@_decorators.dtype_to_torch -def identity(n, dtype=None, *, like=None): - _util.subok_not_ok(like) +@normalizer +def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None): result = torch.eye(n, dtype=dtype) return _helpers.array_from(result) -def diag(v, k=0): - (tensor,) = _helpers.to_tensors(v) - result = torch.diag(tensor, k) +@normalizer +def diag(v: ArrayLike, k=0): + result = torch.diag(v, k) return _helpers.array_from(result) -def diagflat(v, k=0): - (tensor,) = _helpers.to_tensors(v) - result = torch.diagflat(tensor, k) +@normalizer +def diagflat(v: ArrayLike, k=0): + result = torch.diagflat(v, k) return _helpers.array_from(result) @@ -81,68 +100,70 @@ def diag_indices(n, ndim=2): return _helpers.tuple_arrays_from(result) -def diag_indices_from(arr): - (tensor,) = _helpers.to_tensors(arr) - result = _impl.diag_indices_from(tensor) +@normalizer +def diag_indices_from(arr: ArrayLike): + result = _impl.diag_indices_from(arr) return _helpers.tuple_arrays_from(result) -def fill_diagonal(a, 
val, wrap=False): - tensor, t_val = _helpers.to_tensors(a, val) - result = _impl.fill_diagonal(tensor, t_val, wrap) +@normalizer +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): + result = _impl.fill_diagonal(a, val, wrap) return _helpers.array_from(result) -def vdot(a, b, /): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.vdot(t_a, t_b) +@normalizer +def vdot(a: ArrayLike, b: ArrayLike, /): + result = _impl.vdot(a, b) return result.item() -def dot(a, b, out=None): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.dot(t_a, t_b) +@normalizer +def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None): + result = _impl.dot(a, b) return _helpers.result_or_out(result, out) # ### sort and partition ### -def sort(a, axis=-1, kind=None, order=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.sort(tensor, axis, kind, order) +@normalizer +def sort(a: ArrayLike, axis=-1, kind=None, order=None): + result = _impl.sort(a, axis, kind, order) return _helpers.array_from(result) -def argsort(a, axis=-1, kind=None, order=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.argsort(tensor, axis, kind, order) +@normalizer +def argsort(a: ArrayLike, axis=-1, kind=None, order=None): + result = _impl.argsort(a, axis, kind, order) return _helpers.array_from(result) -def searchsorted(a, v, side="left", sorter=None): - a_t, v_t, sorter_t = _helpers.to_tensors_or_none(a, v, sorter) - result = torch.searchsorted(a_t, v_t, side=side, sorter=sorter_t) +@normalizer +def searchsorted( + a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None +): + result = torch.searchsorted(a, v, side=side, sorter=sorter) return _helpers.array_from(result) # ### swap/move/roll axis ### -def moveaxis(a, source, destination): - (tensor,) = _helpers.to_tensors(a) - result = _impl.moveaxis(tensor, source, destination) +@normalizer +def moveaxis(a: ArrayLike, source, destination): + result = _impl.moveaxis(a, source, destination) return _helpers.array_from(result) -def swapaxes(a, axis1, axis2): - (tensor,) = _helpers.to_tensors(a) - result = _flips.swapaxes(tensor, axis1, axis2) +@normalizer +def swapaxes(a: ArrayLike, axis1, axis2): + result = _flips.swapaxes(a, axis1, axis2) return _helpers.array_from(result) -def rollaxis(a, axis, start=0): - (tensor,) = _helpers.to_tensors(a) +@normalizer +def rollaxis(a: ArrayLike, axis, start=0): result = _flips.rollaxis(a, axis, start) return _helpers.array_from(result) @@ -150,57 +171,292 @@ def rollaxis(a, axis, start=0): # ### shape manipulations ### -def squeeze(a, axis=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.squeeze(tensor, axis) +@normalizer +def squeeze(a: ArrayLike, axis=None): + result = _impl.squeeze(a, axis) return _helpers.array_from(result, a) -def reshape(a, newshape, order="C"): - (tensor,) = _helpers.to_tensors(a) - result = _impl.reshape(tensor, newshape, order=order) +@normalizer +def reshape(a: ArrayLike, newshape, order="C"): + result = _impl.reshape(a, newshape, order=order) return _helpers.array_from(result, a) -def transpose(a, axes=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.transpose(tensor, axes) +@normalizer +def transpose(a: ArrayLike, axes=None): + result = _impl.transpose(a, axes) return _helpers.array_from(result, a) -def ravel(a, order="C"): - (tensor,) = _helpers.to_tensors(a) - result = _impl.ravel(tensor) +@normalizer +def ravel(a: ArrayLike, order="C"): + result = _impl.ravel(a) return _helpers.array_from(result, a) # leading underscore since arr.flatten 
exists but np.flatten does not -def _flatten(a, order="C"): - (tensor,) = _helpers.to_tensors(a) - result = _impl._flatten(tensor) +@normalizer +def _flatten(a: ArrayLike, order="C"): + result = _impl._flatten(a) return _helpers.array_from(result, a) # ### Type/shape etc queries ### -def real(a): - (tensor,) = _helpers.to_tensors(a) - result = torch.real(tensor) +@normalizer +def real(a: ArrayLike): + result = torch.real(a) return _helpers.array_from(result) -def imag(a): - (tensor,) = _helpers.to_tensors(a) - result = _impl.imag(tensor) +@normalizer +def imag(a: ArrayLike): + result = _impl.imag(a) return _helpers.array_from(result) -def round_(a, decimals=0, out=None): - (tensor,) = _helpers.to_tensors(a) - result = _impl.round(tensor, decimals) +@normalizer +def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None): + result = _impl.round(a, decimals) return _helpers.result_or_out(result, out) around = round_ round = round_ + + +# ### reductions ### + + +NoValue = None # FIXME + + +@normalizer +def sum( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + initial=NoValue, + where=NoValue, +): + result = _reductions.sum( + a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +@normalizer +def prod( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + initial=NoValue, + where=NoValue, +): + result = _reductions.prod( + a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +product = prod + + +@normalizer +def mean( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + *, + where=NoValue, +): + result = _reductions.mean( + a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +@normalizer +def var( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + ddof=0, + keepdims=NoValue, + *, + where=NoValue, +): + result = _reductions.var( + a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +@normalizer +def std( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + ddof=0, + keepdims=NoValue, + *, + where=NoValue, +): + result = _reductions.std( + a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +@normalizer +def argmin( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + *, + keepdims=NoValue, +): + result = _reductions.argmin(a, axis=axis, keepdims=keepdims) + return _helpers.result_or_out(result, out) + + +@normalizer +def argmax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + *, + keepdims=NoValue, +): + result = _reductions.argmax(a, axis=axis, keepdims=keepdims) + return _helpers.result_or_out(result, out) + + +@normalizer +def amax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + initial=NoValue, + where=NoValue, +): + result = _reductions.max( + a, axis=axis, initial=initial, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +max = amax + + +@normalizer +def amin( + a: ArrayLike, + axis: AxisLike = 
None, + out: Optional[NDArray] = None, + keepdims=NoValue, + initial=NoValue, + where=NoValue, +): + result = _reductions.min( + a, axis=axis, initial=initial, where=where, keepdims=keepdims + ) + return _helpers.result_or_out(result, out) + + +min = amin + + +@normalizer +def ptp( + a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue +): + result = _reductions.ptp(a, axis=axis, keepdims=keepdims) + return _helpers.result_or_out(result, out) + + +@normalizer +def all( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + *, + where=NoValue, +): + result = _reductions.all(a, axis=axis, where=where, keepdims=keepdims) + return _helpers.result_or_out(result, out) + + +@normalizer +def any( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + *, + where=NoValue, +): + result = _reductions.any(a, axis=axis, where=where, keepdims=keepdims) + return _helpers.result_or_out(result, out) + + +@normalizer +def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims=False): + result = _reductions.count_nonzero(a, axis=axis, keepdims=keepdims) + return _helpers.array_from(result) + + +@normalizer +def cumsum( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, +): + result = _reductions.cumsum(a, axis=axis, dtype=dtype) + return _helpers.result_or_out(result, out) + + +@normalizer +def cumprod( + a: ArrayLike, + axis: AxisLike = None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, +): + result = _reductions.cumprod(a, axis=axis, dtype=dtype) + return _helpers.result_or_out(result, out) + + +cumproduct = cumprod + + +@normalizer +def quantile( + a: ArrayLike, + q: ArrayLike, + axis: AxisLike = None, + out: Optional[NDArray] = None, + overwrite_input=False, + method="linear", + keepdims=False, + *, + interpolation=None, +): + if interpolation is not None: + raise ValueError("'interpolation' argument is deprecated; use 'method' instead") + + result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims) + return _helpers.result_or_out(result, out, promote_scalar=True) diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py index e40a6473..a894b3da 100644 --- a/torch_np/_helpers.py +++ b/torch_np/_helpers.py @@ -29,11 +29,6 @@ def cast_and_broadcast(tensors, out, casting): if out is None: return tensors else: - from ._ndarray import asarray, ndarray - - if not isinstance(out, ndarray): - raise TypeError("Return arrays must be of ArrayType") - tensors = _util.cast_and_broadcast( tensors, out.dtype.type.torch_dtype, out.shape, casting ) @@ -41,6 +36,26 @@ def cast_and_broadcast(tensors, out, casting): return tuple(tensors) +def ufunc_preprocess( + tensors, out, where, casting, order, dtype, subok, signature, extobj +): + # internal preprocessing or args in ufuncs (cf _unary_ufuncs, _binary_ufuncs) + if order != "K" or not where or signature or extobj: + raise NotImplementedError + + # XXX: dtype=... parameter + if dtype is not None: + raise NotImplementedError + + out_shape_dtype = None + if out is not None: + out_shape_dtype = (out.get().dtype, out.get().shape) + + tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting) + + return tensors + + # ### Return helpers: wrap a single tensor, a tuple of tensors, out= etc ### @@ -52,11 +67,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False): result_tensor is placed into the out array. This weirdness is used e.g. 
in `np.percentile` """ - from ._ndarray import asarray, ndarray - if out_array is not None: - if not isinstance(out_array, ndarray): - raise TypeError("Return arrays must be of ArrayType") if result_tensor.shape != out_array.shape: can_fit = result_tensor.numel() == 1 and out_array.ndim == 0 if promote_scalar and can_fit: @@ -70,7 +81,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False): out_tensor.copy_(result_tensor) return out_array else: - return asarray(result_tensor) + return array_from(result_tensor) def array_from(tensor, base=None): @@ -117,10 +128,3 @@ def to_tensors(*inputs): from ._ndarray import asarray, ndarray return tuple(asarray(value).get() for value in inputs) - - -def to_tensors_or_none(*inputs): - """Convert all array_likes from `inputs` to tensors. Nones pass through""" - from ._ndarray import asarray, ndarray - - return tuple(None if value is None else asarray(value).get() for value in inputs) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index ea3e0739..3c6f8000 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -1,16 +1,8 @@ -import functools import operator import torch from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs -from ._decorators import ( - NoValue, - axis_keepdims_wrapper, - axis_none_ravel_wrapper, - dtype_to_torch, - emulate_out_arg, -) from ._detail import _dtypes_impl, _flips, _reductions, _util from ._detail import implementations as _impl @@ -375,28 +367,23 @@ def sort(self, axis=-1, kind=None, order=None): searchsorted = _funcs.searchsorted ### reductions ### + argmax = _funcs.argmax + argmin = _funcs.argmin - argmin = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmin)) - argmax = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmax)) + any = _funcs.any + all = _funcs.all + max = _funcs.max + min = _funcs.min + ptp = _funcs.ptp - any = emulate_out_arg(axis_keepdims_wrapper(_reductions.any)) - all = emulate_out_arg(axis_keepdims_wrapper(_reductions.all)) - max = emulate_out_arg(axis_keepdims_wrapper(_reductions.max)) - min = emulate_out_arg(axis_keepdims_wrapper(_reductions.min)) - ptp = emulate_out_arg(axis_keepdims_wrapper(_reductions.ptp)) + sum = _funcs.sum + prod = _funcs.prod + mean = _funcs.mean + var = _funcs.var + std = _funcs.std - sum = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.sum))) - prod = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.prod))) - mean = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.mean))) - var = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.var))) - std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std))) - - cumprod = emulate_out_arg( - axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumprod)) - ) - cumsum = emulate_out_arg( - axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum)) - ) + cumsum = _funcs.cumsum + cumprod = _funcs.cumprod ### indexing ### @staticmethod @@ -470,25 +457,6 @@ def maybe_set_base(tensor, base): return ndarray._from_tensor_and_base(tensor, base) -class asarray_replacer: - def __init__(self, dispatch="one"): - if dispatch not in ["one", "two"]: - raise ValueError("ararray_replacer: unknown dispatch %s" % dispatch) - self._dispatch = dispatch - - def __call__(self, func): - if self._dispatch == "one": - - @functools.wraps(func) - def wrapped(x, *args, **kwds): - x_tensor = asarray(x).get() - return asarray(func(x_tensor, *args, **kwds)) - - return wrapped - else: - raise ValueError - - ###### dtype routines 
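The new `torch_np/_normalizations.py` module in the next hunk is what drives the annotation-based argument handling used throughout `_funcs.py` and `_wrapper.py` above. A minimal, self-contained sketch of the idea follows; the names are illustrative stand-ins rather than the actual torch_np helpers, only positional arguments are handled, and only the `ArrayLike` annotation is covered.

import functools
import inspect
import typing

import torch

# Sentinel annotation: "convert this argument from an array-like to a tensor".
ArrayLike = typing.TypeVar("ArrayLike")

_normalizers = {
    ArrayLike: torch.as_tensor,
}


def normalizer(func):
    """Convert annotated positional arguments before calling `func`."""

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        params = inspect.signature(func).parameters
        args = tuple(
            _normalizers.get(p.annotation, lambda x: x)(arg)
            for arg, p in zip(args, params.values())
        )
        return func(*args, **kwds)

    return wrapped


@normalizer
def my_transpose(a: ArrayLike, axes=None):
    # `a` is already a torch.Tensor here, whatever the caller passed in
    if axes is None:
        axes = reversed(range(a.ndim))
    return a.permute(*axes)


print(my_transpose([[1, 2], [3, 4]]))  # accepts a plain nested list

The real module below additionally handles keyword arguments, variadic positionals (e.g. `atleast_1d(*arys)`), and annotations such as `Optional[ArrayLike]`, `DTypeLike`, `AxisLike` and `Optional[NDArray]` for the `out=` argument.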
diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py new file mode 100644 index 00000000..7f2ace98 --- /dev/null +++ b/torch_np/_normalizations.py @@ -0,0 +1,115 @@ +""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on. +""" +import operator +import typing +from typing import Optional, Sequence + +import torch + +from . import _helpers + +ArrayLike = typing.TypeVar("ArrayLike") +DTypeLike = typing.TypeVar("DTypeLike") +SubokLike = typing.TypeVar("SubokLike") +AxisLike = typing.TypeVar("AxisLike") +NDArray = typing.TypeVar("NDarray") + + +import inspect + +from . import _dtypes + + +def normalize_array_like(x, name=None): + (tensor,) = _helpers.to_tensors(x) + return tensor + + +def normalize_optional_array_like(x, name=None): + # This explicit normalizer is needed because otherwise normalize_array_like + # does not run for a parameter annotated as Optional[ArrayLike] + return None if x is None else normalize_array_like(x, name) + + +def normalize_seq_array_like(x, name=None): + tensors = _helpers.to_tensors(*x) + return tensors + + +def normalize_dtype(dtype, name=None): + # cf _decorators.dtype_to_torch + torch_dtype = None + if dtype is not None: + dtype = _dtypes.dtype(dtype) + torch_dtype = dtype.torch_dtype + return torch_dtype + + +def normalize_subok_like(arg, name): + if arg: + raise ValueError(f"'{name}' parameter is not supported.") + + +def normalize_axis_like(arg, name=None): + from ._ndarray import ndarray + + if isinstance(arg, ndarray): + arg = operator.index(arg) + return arg + + +def normalize_ndarray(arg, name=None): + if arg is None: + return arg + + from ._ndarray import ndarray + + if not isinstance(arg, ndarray): + raise TypeError("'out' must be an array") + return arg + + +normalizers = { + ArrayLike: normalize_array_like, + Optional[ArrayLike]: normalize_optional_array_like, + Sequence[ArrayLike]: normalize_seq_array_like, + Optional[NDArray]: normalize_ndarray, + DTypeLike: normalize_dtype, + SubokLike: normalize_subok_like, + AxisLike: normalize_axis_like, +} + +import functools + + +def maybe_normalize(arg, parm): + """Normalize arg if a normalizer is registred.""" + normalizer = normalizers.get(parm.annotation, None) + return normalizer(arg) if normalizer else arg + + +def normalizer(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + params = inspect.signature(func).parameters + first_param = next(iter(params.values())) + # NumPy's API does not have positional args before variadic positional args + if first_param.kind == inspect.Parameter.VAR_POSITIONAL: + args = [maybe_normalize(arg, first_param) for arg in args] + else: + # NB: extra unknown arguments: pass through, will raise in func(*args) below + args = ( + tuple( + maybe_normalize(arg, parm) + for arg, parm in zip(args, params.values()) + ) + + args[len(params.values()) :] + ) + + kwds = { + name: maybe_normalize(arg, params[name]) if name in params else arg + for name, arg in kwds.items() + } + return func(*args, **kwds) + + return wrapped diff --git a/torch_np/_unary_ufuncs.py b/torch_np/_unary_ufuncs.py index e50e96fb..8990743b 100644 --- a/torch_np/_unary_ufuncs.py +++ b/torch_np/_unary_ufuncs.py @@ -1,103 +1,55 @@ -from ._decorators import deco_unary_ufunc_from_impl -from ._detail import _ufunc_impl +# from ._decorators import deco_unary_ufunc_from_impl +# from ._detail import _ufunc_impl + + +from typing import Optional + +from . 
import _helpers +from ._detail import _unary_ufuncs +from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer __all__ = [ - "abs", - "absolute", - "arccos", - "arccosh", - "arcsin", - "arcsinh", - "arctan", - "arctanh", - "cbrt", - "ceil", - "conj", - "conjugate", - "cos", - "cosh", - "deg2rad", - "degrees", - "exp", - "exp2", - "expm1", - "fabs", - "floor", - "isfinite", - "isinf", - "isnan", - "log", - "log10", - "log1p", - "log2", - "logical_not", - "negative", - "positive", - "rad2deg", - "radians", - "reciprocal", - "rint", - "sign", - "signbit", - "sin", - "sinh", - "sqrt", - "square", - "tan", - "tanh", - "trunc", - "invert", + name for name in dir(_unary_ufuncs) if not name.startswith("_") and name != "torch" ] -absolute = deco_unary_ufunc_from_impl(_ufunc_impl.absolute) -arccos = deco_unary_ufunc_from_impl(_ufunc_impl.arccos) -arccosh = deco_unary_ufunc_from_impl(_ufunc_impl.arccosh) -arcsin = deco_unary_ufunc_from_impl(_ufunc_impl.arcsin) -arcsinh = deco_unary_ufunc_from_impl(_ufunc_impl.arcsinh) -arctan = deco_unary_ufunc_from_impl(_ufunc_impl.arctan) -arctanh = deco_unary_ufunc_from_impl(_ufunc_impl.arctanh) -ceil = deco_unary_ufunc_from_impl(_ufunc_impl.ceil) -conjugate = deco_unary_ufunc_from_impl(_ufunc_impl.conjugate) -cos = deco_unary_ufunc_from_impl(_ufunc_impl.cos) -cosh = deco_unary_ufunc_from_impl(_ufunc_impl.cosh) -deg2rad = deco_unary_ufunc_from_impl(_ufunc_impl.deg2rad) -degrees = deco_unary_ufunc_from_impl(_ufunc_impl.rad2deg) -exp = deco_unary_ufunc_from_impl(_ufunc_impl.exp) -exp2 = deco_unary_ufunc_from_impl(_ufunc_impl.exp2) -expm1 = deco_unary_ufunc_from_impl(_ufunc_impl.expm1) -fabs = deco_unary_ufunc_from_impl(_ufunc_impl.absolute) -floor = deco_unary_ufunc_from_impl(_ufunc_impl.floor) -isfinite = deco_unary_ufunc_from_impl(_ufunc_impl.isfinite) -isinf = deco_unary_ufunc_from_impl(_ufunc_impl.isinf) -isnan = deco_unary_ufunc_from_impl(_ufunc_impl.isnan) -log = deco_unary_ufunc_from_impl(_ufunc_impl.log) -log10 = deco_unary_ufunc_from_impl(_ufunc_impl.log10) -log1p = deco_unary_ufunc_from_impl(_ufunc_impl.log1p) -log2 = deco_unary_ufunc_from_impl(_ufunc_impl.log2) -logical_not = deco_unary_ufunc_from_impl(_ufunc_impl.logical_not) -negative = deco_unary_ufunc_from_impl(_ufunc_impl.negative) -rad2deg = deco_unary_ufunc_from_impl(_ufunc_impl.rad2deg) -radians = deco_unary_ufunc_from_impl(_ufunc_impl.deg2rad) -reciprocal = deco_unary_ufunc_from_impl(_ufunc_impl.reciprocal) -rint = deco_unary_ufunc_from_impl(_ufunc_impl.rint) -sign = deco_unary_ufunc_from_impl(_ufunc_impl.sign) -signbit = deco_unary_ufunc_from_impl(_ufunc_impl.signbit) -sin = deco_unary_ufunc_from_impl(_ufunc_impl.sin) -sinh = deco_unary_ufunc_from_impl(_ufunc_impl.sinh) -sqrt = deco_unary_ufunc_from_impl(_ufunc_impl.sqrt) -square = deco_unary_ufunc_from_impl(_ufunc_impl.square) -tan = deco_unary_ufunc_from_impl(_ufunc_impl.tan) -tanh = deco_unary_ufunc_from_impl(_ufunc_impl.tanh) -trunc = deco_unary_ufunc_from_impl(_ufunc_impl.trunc) +def deco_unary_ufunc(torch_func): + """Common infra for unary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. 
+ """ + + def wrapped( + x: ArrayLike, + /, + out: Optional[NDArray] = None, + *, + where=True, + casting="same_kind", + order="K", + dtype: DTypeLike = None, + subok: SubokLike = False, + signature=None, + extobj=None, + ): + tensors = _helpers.ufunc_preprocess( + (x,), out, where, casting, order, dtype, subok, signature, extobj + ) + result = torch_func(*tensors) + return _helpers.result_or_out(result, out) -invert = deco_unary_ufunc_from_impl(_ufunc_impl.invert) + return wrapped -cbrt = deco_unary_ufunc_from_impl(_ufunc_impl.cbrt) -positive = deco_unary_ufunc_from_impl(_ufunc_impl.positive) +# +# For each torch ufunc implementation, decorate and attach the decorated name +# to this module. Its contents is then exported to the public namespace in __init__.py +# +for name in __all__: + ufunc = getattr(_unary_ufuncs, name) + decorated = normalizer(deco_unary_ufunc(ufunc)) -# numpy has these aliases while torch does not -abs = absolute -conj = conjugate -bitwise_not = invert + decorated.__qualname__ = name # XXX: is this really correct? + decorated.__name__ = name + vars()[name] = decorated diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index dbf0bf9a..39bff119 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -4,14 +4,15 @@ pytorch tensors. """ +from typing import Optional, Sequence + import torch -from . import _funcs +from . import _decorators, _dtypes, _funcs, _helpers from ._detail import _dtypes_impl, _flips, _reductions, _util from ._detail import implementations as _impl -from ._ndarray import array, asarray, asarray_replacer, maybe_set_base, ndarray, newaxis - -from . import _dtypes, _helpers, _decorators # isort: skip # XXX +from ._ndarray import array, asarray, maybe_set_base, ndarray +from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer # Things to decide on (punt for now) # @@ -54,40 +55,39 @@ ###### array creation routines -def copy(a, order="K", subok=False): - a = asarray(a) - _util.subok_not_ok(subok=subok) +@normalizer +def copy(a: ArrayLike, order="K", subok: SubokLike = False): if order != "K": raise NotImplementedError - # XXX: ndarray.copy only accepts order='C' - return a.copy(order="C") + tensor = a.clone() + return _helpers.array_from(tensor) -def atleast_1d(*arys): - tensors = _helpers.to_tensors(*arys) - res = torch.atleast_1d(tensors) - if len(res) == 1: - return asarray(res[0]) +@normalizer +def atleast_1d(*arys: ArrayLike): + res = torch.atleast_1d(*arys) + if isinstance(res, tuple): + return list(_helpers.tuple_arrays_from(res)) else: - return list(asarray(_) for _ in res) + return _helpers.array_from(res) -def atleast_2d(*arys): - tensors = _helpers.to_tensors(*arys) - res = torch.atleast_2d(tensors) - if len(res) == 1: - return asarray(res[0]) +@normalizer +def atleast_2d(*arys: ArrayLike): + res = torch.atleast_2d(*arys) + if isinstance(res, tuple): + return list(_helpers.tuple_arrays_from(res)) else: - return list(asarray(_) for _ in res) + return _helpers.array_from(res) -def atleast_3d(*arys): - tensors = _helpers.to_tensors(*arys) - res = torch.atleast_3d(tensors) - if len(res) == 1: - return asarray(res[0]) +@normalizer +def atleast_3d(*arys: ArrayLike): + res = torch.atleast_3d(*arys) + if isinstance(res, tuple): + return list(_helpers.tuple_arrays_from(res)) else: - return list(asarray(_) for _ in res) + return _helpers.array_from(res) def _concat_check(tup, dtype, out): @@ -97,9 +97,6 @@ def _concat_check(tup, dtype, out): raise ValueError("need at least one array to concatenate") if out is not 
None: - if not isinstance(out, ndarray): - raise ValueError("'out' must be an array") - if dtype is not None: # mimic numpy raise TypeError( @@ -108,58 +105,67 @@ def _concat_check(tup, dtype, out): ) -@_decorators.dtype_to_torch -def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): - tensors = _helpers.to_tensors(*ar_tuple) - _concat_check(tensors, dtype, out=out) - result = _impl.concatenate(tensors, axis, out, dtype, casting) +@normalizer +def concatenate( + ar_tuple: Sequence[ArrayLike], + axis=0, + out: Optional[NDArray] = None, + dtype: DTypeLike = None, + casting="same_kind", +): + _concat_check(ar_tuple, dtype, out=out) + result = _impl.concatenate(ar_tuple, axis, out, dtype, casting) return _helpers.result_or_out(result, out) -@_decorators.dtype_to_torch -def vstack(tup, *, dtype=None, casting="same_kind"): - tensors = _helpers.to_tensors(*tup) - _concat_check(tensors, dtype, out=None) - result = _impl.vstack(tensors, dtype=dtype, casting=casting) - return asarray(result) +@normalizer +def vstack(tup: Sequence[ArrayLike], *, dtype: DTypeLike = None, casting="same_kind"): + _concat_check(tup, dtype, out=None) + result = _impl.vstack(tup, dtype=dtype, casting=casting) + return _helpers.array_from(result) row_stack = vstack -@_decorators.dtype_to_torch -def hstack(tup, *, dtype=None, casting="same_kind"): - tensors = _helpers.to_tensors(*tup) - _concat_check(tensors, dtype, out=None) - result = _impl.hstack(tensors, dtype=dtype, casting=casting) - return asarray(result) +@normalizer +def hstack(tup: Sequence[ArrayLike], *, dtype: DTypeLike = None, casting="same_kind"): + _concat_check(tup, dtype, out=None) + result = _impl.hstack(tup, dtype=dtype, casting=casting) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def dstack(tup, *, dtype=None, casting="same_kind"): +@normalizer +def dstack(tup: Sequence[ArrayLike], *, dtype: DTypeLike = None, casting="same_kind"): # XXX: in numpy 1.24 dstack does not have dtype and casting keywords # but {h,v}stack do. Hence add them here for consistency. - tensors = _helpers.to_tensors(*tup) - result = _impl.dstack(tensors, dtype=dtype, casting=casting) - return asarray(result) + result = _impl.dstack(tup, dtype=dtype, casting=casting) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def column_stack(tup, *, dtype=None, casting="same_kind"): +@normalizer +def column_stack( + tup: Sequence[ArrayLike], *, dtype: DTypeLike = None, casting="same_kind" +): # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords # but row_stack does. (because row_stack is an alias for vstack, really). # Hence add these keywords here for consistency. 
- tensors = _helpers.to_tensors(*tup) - _concat_check(tensors, dtype, out=None) - result = _impl.column_stack(tensors, dtype=dtype, casting=casting) - return asarray(result) + _concat_check(tup, dtype, out=None) + result = _impl.column_stack(tup, dtype=dtype, casting=casting) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"): - tensors = _helpers.to_tensors(*arrays) - _concat_check(tensors, dtype, out=out) - result = _impl.stack(tensors, axis=axis, out=out, dtype=dtype, casting=casting) +@normalizer +def stack( + arrays: Sequence[ArrayLike], + axis=0, + out: Optional[NDArray] = None, + *, + dtype: DTypeLike = None, + casting="same_kind", +): + _concat_check(arrays, dtype, out=out) + result = _impl.stack(arrays, axis=axis, out=out, dtype=dtype, casting=casting) return _helpers.result_or_out(result, out) @@ -198,143 +204,177 @@ def dsplit(ary, indices_or_sections): return tuple(maybe_set_base(x, base) for x in result) -def kron(a, b): - a_tensor, b_tensor = _helpers.to_tensors(a, b) - result = torch.kron(a_tensor, b_tensor) - return asarray(result) +@normalizer +def kron(a: ArrayLike, b: ArrayLike): + result = torch.kron(a, b) + return _helpers.array_from(result) -def tile(A, reps): - a_tensor = asarray(A).get() +@normalizer +def tile(A: ArrayLike, reps): if isinstance(reps, int): reps = (reps,) - result = torch.tile(a_tensor, reps) - return asarray(result) + result = torch.tile(A, reps) + return _helpers.array_from(result) -def vander(x, N=None, increasing=False): - x_tensor = asarray(x).get() - result = torch.vander(x_tensor, N, increasing) - return asarray(result) +@normalizer +def vander(x: ArrayLike, N=None, increasing=False): + result = torch.vander(x, N, increasing) + return _helpers.array_from(result) def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): if axis != 0 or retstep or not endpoint: raise NotImplementedError # XXX: raises TypeError if start or stop are not scalars - return asarray(torch.linspace(start, stop, num, dtype=dtype)) - - -@_decorators.dtype_to_torch -def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): + result = torch.linspace(start, stop, num, dtype=dtype) + return _helpers.array_from(result) + + +@normalizer +def geomspace( + start: ArrayLike, + stop: ArrayLike, + num=50, + endpoint=True, + dtype: DTypeLike = None, + axis=0, +): if axis != 0 or not endpoint: raise NotImplementedError - start, stop = _helpers.to_tensors(start, stop) result = _impl.geomspace(start, stop, num, endpoint, dtype, axis) - return asarray(result) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): +@normalizer +def logspace( + start, stop, num=50, endpoint=True, base=10.0, dtype: DTypeLike = None, axis=0 +): if axis != 0 or not endpoint: raise NotImplementedError - return asarray(torch.logspace(start, stop, num, base=base, dtype=dtype)) + result = torch.logspace(start, stop, num, base=base, dtype=dtype) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def arange(start=None, stop=None, step=1, dtype=None, *, like=None): - _util.subok_not_ok(like) - start, stop, step = _helpers.ndarrays_to_tensors(start, stop, step) +@normalizer +def arange( + start: Optional[ArrayLike] = None, + stop: Optional[ArrayLike] = None, + step: Optional[ArrayLike] = 1, + dtype: DTypeLike = None, + *, + like: SubokLike = None, +): result = _impl.arange(start, 
stop, step, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def empty(shape, dtype=float, order="C", *, like=None): - _util.subok_not_ok(like) +@normalizer +def empty(shape, dtype: DTypeLike = float, order="C", *, like: SubokLike = None): if order != "C": raise NotImplementedError if dtype is None: dtype = _dtypes_impl.default_float_dtype result = torch.empty(shape, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -# NB: *_like function deliberately deviate from numpy: it has subok=True +# NB: *_like functions deliberately deviate from numpy: it has subok=True # as the default; we set subok=False and raise on anything else. -@asarray_replacer() -@_decorators.dtype_to_torch -def empty_like(prototype, dtype=None, order="K", subok=False, shape=None): - _util.subok_not_ok(subok=subok) +@normalizer +def empty_like( + prototype: ArrayLike, + dtype: DTypeLike = None, + order="K", + subok: SubokLike = False, + shape=None, +): if order != "K": raise NotImplementedError result = _impl.empty_like(prototype, dtype=dtype, shape=shape) - return result + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def full(shape, fill_value, dtype=None, order="C", *, like=None): - _util.subok_not_ok(like) +@normalizer +def full( + shape, + fill_value: ArrayLike, + dtype: DTypeLike = None, + order="C", + *, + like: SubokLike = None, +): if isinstance(shape, int): shape = (shape,) if order != "C": raise NotImplementedError - fill_value = asarray(fill_value).get() result = _impl.full(shape, fill_value, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -@asarray_replacer() -@_decorators.dtype_to_torch -def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None): - _util.subok_not_ok(subok=subok) +@normalizer +def full_like( + a: ArrayLike, + fill_value, + dtype: DTypeLike = None, + order="K", + subok: SubokLike = False, + shape=None, +): if order != "K": raise NotImplementedError result = _impl.full_like(a, fill_value, dtype=dtype, shape=shape) - return result + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def ones(shape, dtype=None, order="C", *, like=None): - _util.subok_not_ok(like) +@normalizer +def ones(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None): if order != "C": raise NotImplementedError if dtype is None: dtype = _dtypes_impl.default_float_dtype result = torch.ones(shape, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -@asarray_replacer() -@_decorators.dtype_to_torch -def ones_like(a, dtype=None, order="K", subok=False, shape=None): - _util.subok_not_ok(subok=subok) +@normalizer +def ones_like( + a: ArrayLike, + dtype: DTypeLike = None, + order="K", + subok: SubokLike = False, + shape=None, +): if order != "K": raise NotImplementedError result = _impl.ones_like(a, dtype=dtype, shape=shape) - return result + return _helpers.array_from(result) -@_decorators.dtype_to_torch -def zeros(shape, dtype=None, order="C", *, like=None): - _util.subok_not_ok(like) +@normalizer +def zeros(shape, dtype: DTypeLike = None, order="C", *, like: SubokLike = None): if order != "C": raise NotImplementedError if dtype is None: dtype = _dtypes_impl.default_float_dtype result = torch.zeros(shape, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -@asarray_replacer() -@_decorators.dtype_to_torch -def zeros_like(a, dtype=None, order="K", subok=False, shape=None): - _util.subok_not_ok(subok=subok) 
+@normalizer +def zeros_like( + a: ArrayLike, + dtype: DTypeLike = None, + order="K", + subok: SubokLike = False, + shape=None, +): if order != "K": raise NotImplementedError result = _impl.zeros_like(a, dtype=dtype, shape=shape) - return result + return _helpers.array_from(result) ###### misc/unordered @@ -365,77 +405,80 @@ def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True): return x_tensor -@_decorators.dtype_to_torch -def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None): +@normalizer +def corrcoef( + x: ArrayLike, + y: Optional[ArrayLike] = None, + rowvar=True, + bias=NoValue, + ddof=NoValue, + *, + dtype: DTypeLike = None, +): if bias is not None or ddof is not None: # deprecated in NumPy raise NotImplementedError - - x_tensor, y_tensor = _helpers.to_tensors_or_none(x, y) - tensor = _xy_helper_corrcoef(x_tensor, y_tensor, rowvar) + tensor = _xy_helper_corrcoef(x, y, rowvar) result = _impl.corrcoef(tensor, dtype=dtype) - return asarray(result) + return _helpers.array_from(result) -@_decorators.dtype_to_torch +@normalizer def cov( - m, - y=None, + m: ArrayLike, + y: Optional[ArrayLike] = None, rowvar=True, bias=False, ddof=None, - fweights=None, - aweights=None, + fweights: Optional[ArrayLike] = None, + aweights: Optional[ArrayLike] = None, *, - dtype=None, + dtype: DTypeLike = None, ): - - m_tensor, y_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none( - m, y, fweights, aweights - ) - m_tensor = _xy_helper_corrcoef(m_tensor, y_tensor, rowvar) - - result = _impl.cov( - m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype - ) - return asarray(result) + m = _xy_helper_corrcoef(m, y, rowvar) + result = _impl.cov(m, bias, ddof, fweights, aweights, dtype=dtype) + return _helpers.array_from(result) -def bincount(x, /, weights=None, minlength=0): - if not isinstance(x, ndarray) and x == []: +@normalizer +def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0): + if x.numel() == 0: # edge case allowed by numpy - x = asarray([], dtype=int) - - x_tensor, weights_tensor = _helpers.to_tensors_or_none(x, weights) - result = _impl.bincount(x_tensor, weights_tensor, minlength) - return asarray(result) + x = torch.as_tensor([], dtype=int) + result = _impl.bincount(x, weights, minlength) + return _helpers.array_from(result) -def where(condition, x=None, y=None, /): - cond_t, x_t, y_t = _helpers.to_tensors_or_none(condition, x, y) - result = _impl.where(cond_t, x_t, y_t) +@normalizer +def where( + condition: ArrayLike, + x: Optional[ArrayLike] = None, + y: Optional[ArrayLike] = None, + /, +): + result = _impl.where(condition, x, y) if isinstance(result, tuple): # single-argument where(condition) - return tuple(asarray(x) for x in result) + return _helpers.tuple_arrays_from(result) else: - return asarray(result) + return _helpers.array_from(result) ###### module-level queries of object properties -def ndim(a): - a = asarray(a).get() +@normalizer +def ndim(a: ArrayLike): return a.ndim -def shape(a): - a = asarray(a).get() +@normalizer +def shape(a: ArrayLike): return tuple(a.shape) -def size(a, axis=None): - a = asarray(a).get() +@normalizer +def size(a: ArrayLike, axis=None): if axis is None: return a.numel() else: @@ -445,48 +488,51 @@ def size(a, axis=None): ###### shape manipulations and indexing -def expand_dims(a, axis): - a = asarray(a) +@normalizer +def expand_dims(a: ArrayLike, axis): shape = _util.expand_shape(a.shape, axis) - tensor = a.get().view(shape) # never copies - return 
ndarray._from_tensor_and_base(tensor, a) + tensor = a.view(shape) # never copies + return _helpers.array_from(tensor, a) -@asarray_replacer() -def flip(m, axis=None): - return _flips.flip(m, axis) +@normalizer +def flip(m: ArrayLike, axis=None): + result = _flips.flip(m, axis) + return _helpers.array_from(result) -@asarray_replacer() -def flipud(m): - return _flips.flipud(m) +@normalizer +def flipud(m: ArrayLike): + result = _flips.flipud(m) + return _helpers.array_from(result) -@asarray_replacer() -def fliplr(m): - return _flips.fliplr(m) +@normalizer +def fliplr(m: ArrayLike): + result = _flips.fliplr(m) + return _helpers.array_from(result) -@asarray_replacer() -def rot90(m, k=1, axes=(0, 1)): - return _flips.rot90(m, k, axes) +@normalizer +def rot90(m: ArrayLike, k=1, axes=(0, 1)): + result = _flips.rot90(m, k, axes) + return _helpers.array_from(result) -@asarray_replacer() -def broadcast_to(array, shape, subok=False): - _util.subok_not_ok(subok=subok) - return torch.broadcast_to(array, size=shape) +@normalizer +def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False): + result = torch.broadcast_to(array, size=shape) + return _helpers.array_from(result) from torch import broadcast_shapes # YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero -def broadcast_arrays(*args, subok=False): - _util.subok_not_ok(subok=subok) - tensors = _helpers.to_tensors(*args) - res = torch.broadcast_tensors(*tensors) - return tuple(asarray(_) for _ in res) +@normalizer +def broadcast_arrays(*args: ArrayLike, subok: SubokLike = False): + res = torch.broadcast_tensors(*args) + return _helpers.tuple_arrays_from(res) def unravel_index(indices, shape, order="C"): @@ -510,19 +556,20 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): return sum(idx * dim for idx, dim in zip(multi_index, dims)) -def meshgrid(*xi, copy=True, sparse=False, indexing="xy"): - xi_tensors = _helpers.to_tensors(*xi) - output = _impl.meshgrid(*xi_tensors, copy=copy, sparse=sparse, indexing=indexing) - return [asarray(t) for t in output] +@normalizer +def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"): + output = _impl.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) + outp = _helpers.tuple_arrays_from(output) + return list(outp) # match numpy, return a list -@_decorators.dtype_to_torch -def indices(dimensions, dtype=int, sparse=False): +@normalizer +def indices(dimensions, dtype: DTypeLike = int, sparse=False): result = _impl.indices(dimensions, dtype=dtype, sparse=sparse) if sparse: - return tuple(asarray(x) for x in result) + return _helpers.tuple_arrays_from(result) else: - return asarray(result) + return _helpers.array_from(result) def flatnonzero(a): @@ -530,246 +577,125 @@ def flatnonzero(a): return _funcs.nonzero(arr.ravel())[0] -from ._decorators import emulate_out_arg -from ._ndarray import axis_keepdims_wrapper - -count_nonzero = emulate_out_arg(axis_keepdims_wrapper(_reductions.count_nonzero)) - - -def roll(a, shift, axis=None): - tensor = asarray(a).get() - result = _impl.roll(tensor, shift, axis) - return asarray(result) +@normalizer +def roll(a: ArrayLike, shift, axis=None): + result = _impl.roll(a, shift, axis) + return _helpers.array_from(result) ###### tri{l, u} and related -@asarray_replacer() -def tril(m, k=0): - return m.tril(k) +@normalizer +def tril(m: ArrayLike, k=0): + result = m.tril(k) + return _helpers.array_from(result) -@asarray_replacer() -def triu(m, k=0): - return m.triu(k) +@normalizer +def triu(m: ArrayLike, k=0): + result = 
m.triu(k) + return _helpers.array_from(result) def tril_indices(n, k=0, m=None): result = _impl.tril_indices(n, k, m) - return tuple(asarray(t) for t in result) + return _helpers.tuple_arrays_from(result) def triu_indices(n, k=0, m=None): result = _impl.triu_indices(n, k, m) - return tuple(asarray(t) for t in result) + return _helpers.tuple_arrays_from(result) -def tril_indices_from(arr, k=0): - tensor = asarray(arr).get() - result = _impl.tril_indices_from(tensor, k) - return tuple(asarray(t) for t in result) +@normalizer +def tril_indices_from(arr: ArrayLike, k=0): + result = _impl.tril_indices_from(arr, k) + return _helpers.tuple_arrays_from(result) -def triu_indices_from(arr, k=0): - tensor = asarray(arr).get() - result = _impl.triu_indices_from(tensor, k) - return tuple(asarray(t) for t in result) +@normalizer +def triu_indices_from(arr: ArrayLike, k=0): + result = _impl.triu_indices_from(arr, k) + return _helpers.tuple_arrays_from(result) -@_decorators.dtype_to_torch -def tri(N, M=None, k=0, dtype=float, *, like=None): - _util.subok_not_ok(like) +@normalizer +def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: SubokLike = None): result = _impl.tri(N, M, k, dtype) - return asarray(result) + return _helpers.array_from(result) ###### reductions -def argmax(a, axis=None, out=None, *, keepdims=NoValue): - arr = asarray(a) - return arr.argmax(axis=axis, out=out, keepdims=keepdims) - - -def argmin(a, axis=None, out=None, *, keepdims=NoValue): - arr = asarray(a) - return arr.argmin(axis=axis, out=out, keepdims=keepdims) - - -def amax(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue): - arr = asarray(a) - return arr.max(axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - -max = amax - - -def amin(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue): - arr = asarray(a) - return arr.min(axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - -min = amin - - -def ptp(a, axis=None, out=None, keepdims=NoValue): - arr = asarray(a) - return arr.ptp(axis=axis, out=out, keepdims=keepdims) - - -def all(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue): - arr = asarray(a) - return arr.all(axis=axis, out=out, keepdims=keepdims, where=where) - - -def any(a, axis=None, out=None, keepdims=NoValue, *, where=NoValue): - arr = asarray(a) - return arr.any(axis=axis, out=out, keepdims=keepdims, where=where) - - -def mean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): - arr = asarray(a) - return arr.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - -# YYY: pattern: initial=... 
- -def sum( - a, axis=None, dtype=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue -): - arr = asarray(a) - return arr.sum( - axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where - ) - - -def prod( - a, axis=None, dtype=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue +@normalizer +def average( + a: ArrayLike, + axis=None, + weights: ArrayLike = None, + returned=False, + *, + keepdims=NoValue, ): - arr = asarray(a) - return arr.prod( - axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where - ) - - -product = prod - - -def cumprod(a, axis=None, dtype=None, out=None): - arr = asarray(a) - return arr.cumprod(axis=axis, dtype=dtype, out=out) - - -cumproduct = cumprod - - -def cumsum(a, axis=None, dtype=None, out=None): - arr = asarray(a) - return arr.cumsum(axis=axis, dtype=dtype, out=out) - - -# YYY: pattern : ddof - - -def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): - arr = asarray(a) - return arr.std( - axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where - ) - - -def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): - arr = asarray(a) - return arr.var( - axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where + result, wsum = _reductions.average( + a, axis, weights, returned=returned, keepdims=keepdims ) - - -def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue): - - if weights is None: - result = mean(a, axis=axis, keepdims=keepdims) - if returned: - scl = result.dtype.type(a.size / result.size) - return result, scl - return result - - a_tensor, w_tensor = _helpers.to_tensors(a, weights) - - result, wsum = _reductions.average(a_tensor, axis, w_tensor) - - # keepdims - if keepdims: - result = _util.apply_keepdims(result, axis, a_tensor.ndim) - - # returned if returned: - scl = wsum - if scl.shape != result.shape: - scl = torch.broadcast_to(scl, result.shape).clone() - - return asarray(result), asarray(scl) + return _helpers.tuple_arrays_from((result, wsum)) else: - return asarray(result) + return _helpers.array_from(result) +# Normalizations (ArrayLike et al) in percentile and median are done in `_funcs.py/quantile`. 
def percentile( a, q, axis=None, - out=None, + out: Optional[NDArray] = None, overwrite_input=False, method="linear", keepdims=False, *, interpolation=None, ): - return quantile( + return _funcs.quantile( a, asarray(q) / 100.0, axis, out, overwrite_input, method, keepdims=keepdims ) -def quantile( - a, - q, - axis=None, - out=None, - overwrite_input=False, - method="linear", - keepdims=False, - *, - interpolation=None, +def median( + a, axis=None, out: Optional[NDArray] = None, overwrite_input=False, keepdims=False ): - if interpolation is not None: - raise ValueError("'interpolation' argument is deprecated; use 'method' instead") - - a_tensor, q_tensor = _helpers.to_tensors(a, q) - result = _reductions.quantile(a_tensor, q_tensor, axis, method) - - # keepdims - if keepdims: - result = _util.apply_keepdims(result, axis, a_tensor.ndim) - return _helpers.result_or_out(result, out, promote_scalar=True) - - -def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): - return quantile( + return _funcs.quantile( a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims ) -def inner(a, b, /): - t_a, t_b = _helpers.to_tensors(a, b) - result = _impl.inner(t_a, t_b) - return asarray(result) +@normalizer +def inner(a: ArrayLike, b: ArrayLike, /): + result = _impl.inner(a, b) + return _helpers.array_from(result) -def outer(a, b, out=None): - a_t, b_t = _helpers.to_tensors(a, b) - result = torch.outer(a_t, b_t) +@normalizer +def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None): + result = torch.outer(a, b) return _helpers.result_or_out(result, out) -@asarray_replacer() -def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): +# ### FIXME: this is a stub + + +@normalizer +def nanmean( + a: ArrayLike, + axis=None, + dtype: DTypeLike = None, + out: Optional[NDArray] = None, + keepdims=NoValue, + *, + where=NoValue, +): + # XXX: this needs to be rewritten if where is not NoValue: raise NotImplementedError if dtype is None: @@ -782,7 +708,7 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims)) if out is not None: out.copy_(result) - return result + return _helpers.array_from(result) def nanmin(): @@ -837,81 +763,91 @@ def nanpercentile(): raise NotImplementedError -def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue): +@normalizer +def diff( + a: ArrayLike, + n=1, + axis=-1, + prepend: Optional[ArrayLike] = NoValue, + append: Optional[ArrayLike] = NoValue, +): if n == 0: # match numpy and return the input immediately - return a - - a_tensor, prepend_tensor, append_tensor = _helpers.to_tensors_or_none( - a, prepend, append - ) + return _helpers.array_from(result) result = _impl.diff( - a_tensor, + a, n=n, axis=axis, - prepend_tensor=prepend_tensor, - append_tensor=append_tensor, + prepend_tensor=prepend, + append_tensor=append, ) - return asarray(result) + return _helpers.array_from(result) ##### math functions -@asarray_replacer() -def angle(z, deg=False): +@normalizer +def angle(z: ArrayLike, deg=False): result = _impl.angle(z, deg) - return result + return _helpers.array_from(result) -@asarray_replacer() -def sinc(x): - return torch.sinc(x) +@normalizer +def sinc(x: ArrayLike): + result = torch.sinc(x) + return _helpers.array_from(result) -@asarray_replacer() -def real_if_close(a, tol=100): +@normalizer +def real_if_close(a: ArrayLike, tol=100): result = _impl.real_if_close(a, tol=tol) - return result + return 
_helpers.array_from(result) -@asarray_replacer() -def iscomplex(x): +@normalizer +def iscomplex(x: ArrayLike): result = _impl.iscomplex(x) - return result # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ? + # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ? + return _helpers.array_from(result) -@asarray_replacer() -def isreal(x): +@normalizer +def isreal(x: ArrayLike): result = _impl.isreal(x) - return result + return _helpers.array_from(result) -@asarray_replacer() -def iscomplexobj(x): - return torch.is_complex(x) +@normalizer +def iscomplexobj(x: ArrayLike): + result = torch.is_complex(x) + return result -@asarray_replacer() -def isrealobj(x): - return not torch.is_complex(x) +@normalizer +def isrealobj(x: ArrayLike): + result = not torch.is_complex(x) + return result -@asarray_replacer() -def isneginf(x, out=None): - return torch.isneginf(x, out=out) +@normalizer +def isneginf(x: ArrayLike, out: Optional[NDArray] = None): + result = torch.isneginf(x, out=out) + return _helpers.array_from(result) -@asarray_replacer() -def isposinf(x, out=None): - return torch.isposinf(x, out=out) +@normalizer +def isposinf(x: ArrayLike, out: Optional[NDArray] = None): + result = torch.isposinf(x, out=out) + return _helpers.array_from(result) -@asarray_replacer() -def i0(x): - return torch.special.i0(x) +@normalizer +def i0(x: ArrayLike): + result = torch.special.i0(x) + return _helpers.array_from(result) def isscalar(a): @@ -923,27 +859,27 @@ def isscalar(a): return False -def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): - a_t, b_t = _helpers.to_tensors(a, b) - result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) - return asarray(result) +@normalizer +def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): + result = _impl.isclose(a, b, rtol, atol, equal_nan=equal_nan) + return _helpers.array_from(result) -def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - a_t, b_t = _helpers.to_tensors(a, b) - result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) +@normalizer +def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False): + result = _impl.isclose(a, b, rtol, atol, equal_nan=equal_nan) return result.all() -def array_equal(a1, a2, equal_nan=False): - a1_t, a2_t = _helpers.to_tensors(a1, a2) - result = _impl.tensor_equal(a1_t, a2_t, equal_nan) +@normalizer +def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False): + result = _impl.tensor_equal(a1, a2, equal_nan) return result -def array_equiv(a1, a2): - a1_t, a2_t = _helpers.to_tensors(a1, a2) - result = _impl.tensor_equiv(a1_t, a2_t) +@normalizer +def array_equiv(a1: ArrayLike, a2: ArrayLike): + result = _impl.tensor_equiv(a1, a2) return result @@ -966,24 +902,26 @@ def asfarray(): # ### put/take_along_axis ### -def take_along_axis(arr, indices, axis): - tensor, t_indices = _helpers.to_tensors(arr, indices) - result = _impl.take_along_dim(tensor, t_indices, axis) - return asarray(result) +@normalizer +def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): + result = _impl.take_along_dim(arr, indices, axis) + return _helpers.array_from(result) -def put_along_axis(arr, indices, values, axis): - tensor, t_indices, t_values = _helpers.to_tensors(arr, indices, values) - # modify the argument in-place - arr._tensor = _impl.put_along_dim(tensor, t_indices, t_values, axis) +@normalizer +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): + # modify the argument in-place 
: here `arr` is `arr._tensor` of the orignal `arr` argument + result = _impl.put_along_dim(arr, indices, values, axis) + arr.copy_(result.reshape(arr.shape)) return None # ### unqiue et al ### +@normalizer def unique( - ar, + ar: ArrayLike, return_index=False, return_inverse=False, return_counts=False, @@ -991,9 +929,8 @@ def unique( *, equal_nan=True, ): - tensor = asarray(ar).get() result = _impl.unique( - tensor, + ar, return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, @@ -1002,9 +939,9 @@ def unique( ) if isinstance(result, tuple): - return tuple(asarray(x) for x in result) + return _helpers.tuple_arrays_from(result) else: - return asarray(result) + return _helpers.array_from(result) ###### mapping from numpy API objects to wrappers from this module ###### diff --git a/torch_np/random.py b/torch_np/random.py index 4bc3426c..21d5faa2 100644 --- a/torch_np/random.py +++ b/torch_np/random.py @@ -6,11 +6,13 @@ """ from math import sqrt +from typing import Optional import torch -from . import asarray +from . import _helpers from ._detail import _dtypes_impl, _util +from ._normalizations import ArrayLike, normalizer _default_dtype = _dtypes_impl.default_float_dtype @@ -33,7 +35,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False): if return_scalar: return py_type(values.item()) else: - return asarray(values) + return _helpers.array_from(values) def seed(seed=None): @@ -75,11 +77,11 @@ def normal(loc=0.0, scale=1.0, size=None): return array_or_scalar(values, return_scalar=size is None) -def shuffle(x): - x_tensor = asarray(x).get() - perm = torch.randperm(x_tensor.shape[0]) - xp = x_tensor[perm] - x_tensor.copy_(xp) +@normalizer +def shuffle(x: ArrayLike): + perm = torch.randperm(x.shape[0]) + xp = x[perm] + x.copy_(xp) def randint(low, high=None, size=None): @@ -93,12 +95,14 @@ def randint(low, high=None, size=None): return array_or_scalar(values, int, return_scalar=size is None) -def choice(a, size=None, replace=True, p=None): +@normalizer +def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): + # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch - if isinstance(a, int): - a_tensor = torch.arange(a) - else: - a_tensor = asarray(a).get() + if a.numel() == 1: + a = torch.arange(a) + + # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises # number of draws if size is None: @@ -112,21 +116,19 @@ def choice(a, size=None, replace=True, p=None): # prepare the probabilities if p is None: - p_tensor = torch.ones_like(a_tensor) / a_tensor.shape[0] - else: - p_tensor = asarray(p, dtype=float).get() + p = torch.ones_like(a) / a.shape[0] # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973 atol = sqrt(torch.finfo(torch.float64).eps) - if abs(p_tensor.sum() - 1.0) > atol: + if abs(p.sum() - 1.0) > atol: raise ValueError("probabilities do not sum to 1.") # actually sample - indices = torch.multinomial(p_tensor, num_el, replacement=replace) + indices = torch.multinomial(p, num_el, replacement=replace) if _util.is_sequence(size): indices = indices.reshape(size) - samples = a_tensor[indices] + samples = a[indices] - return asarray(samples) + return _helpers.array_from(samples) diff --git a/torch_np/tests/numpy_tests/core/test_multiarray.py b/torch_np/tests/numpy_tests/core/test_multiarray.py index 49e8f01d..8ec56146 100644 --- a/torch_np/tests/numpy_tests/core/test_multiarray.py +++ b/torch_np/tests/numpy_tests/core/test_multiarray.py @@ -3743,7 +3743,6 @@ 
def test_ret_is_out(self, ndim, method): ret = arg_method(axis=0, out=out) assert ret is out - @pytest.mark.xfail(reason='FIXME: out w/ positional args?') @pytest.mark.parametrize('arr_method, np_method', [('argmax', np.argmax), ('argmin', np.argmin)]) diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index 9b0e7df4..0d1b01ab 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -2486,7 +2486,6 @@ def test_mode(self): class TestDtypePositional: - @pytest.mark.xfail(reason='TODO: restore dtypes as positional args') def test_dtype_positional(self): np.empty((2,), bool) diff --git a/torch_np/tests/numpy_tests/lib/test_function_base.py b/torch_np/tests/numpy_tests/lib/test_function_base.py index 4bb19689..c6ea3eaf 100644 --- a/torch_np/tests/numpy_tests/lib/test_function_base.py +++ b/torch_np/tests/numpy_tests/lib/test_function_base.py @@ -732,7 +732,7 @@ def test_n(self): assert_raises(ValueError, diff, x, n=-1) output = [diff(x, n=n) for n in range(1, 5)] expected = [[1, 1], [0], [], []] - assert_(diff(x, n=0) is x) + ## assert_(diff(x, n=0) is x) for n, (expected, out) in enumerate(zip(expected, output), start=1): assert_(type(out) is np.ndarray) assert_array_equal(out, expected) diff --git a/torch_np/tests/numpy_tests/lib/test_shape_base_.py b/torch_np/tests/numpy_tests/lib/test_shape_base_.py index 63aa0b24..cd43c8fb 100644 --- a/torch_np/tests/numpy_tests/lib/test_shape_base_.py +++ b/torch_np/tests/numpy_tests/lib/test_shape_base_.py @@ -597,7 +597,7 @@ def test_basic(self): assert type(res) is np.ndarray aa = np.ones((3, 1, 4, 1, 1)) - assert aa.squeeze().base is aa + assert aa.squeeze().get()._base is aa.get() def test_squeeze_axis(self): A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]] @@ -719,7 +719,7 @@ def test_kroncompare(self): for s in shape: b = randint(0, 10, size=s) for r in reps: - a = np.ones(r, dtype=b.dtype) # TODO: restore dtype positional arg + a = np.ones(r, b.dtype) large = tile(b, r) klarge = kron(a, b) assert_equal(large, klarge) diff --git a/torch_np/tests/test_basic.py b/torch_np/tests/test_basic.py index f102793d..23ea784a 100644 --- a/torch_np/tests/test_basic.py +++ b/torch_np/tests/test_basic.py @@ -1,8 +1,9 @@ import functools -import numpy as np +import numpy as _np import pytest import torch +from pytest import raises as assert_raises import torch_np as w import torch_np._unary_ufuncs as _unary_ufuncs @@ -25,9 +26,7 @@ w.angle, w.real_if_close, w.isreal, - w.isrealobj, w.iscomplex, - w.iscomplexobj, w.isneginf, w.isposinf, w.i0, @@ -44,9 +43,10 @@ w.flatnonzero, ] +ufunc_names = _unary_ufuncs.__all__ +ufunc_names.remove("invert") # torch: bitwise_not_cpu not implemented for 'Float' -one_arg_funcs += [getattr(_unary_ufuncs, name) for name in _unary_ufuncs.__all__] -one_arg_funcs = one_arg_funcs[:-1] # FIXME: remove np.invert +one_arg_funcs += [getattr(_unary_ufuncs, name) for name in ufunc_names] @pytest.mark.parametrize("func", one_arg_funcs) @@ -211,7 +211,7 @@ def test_array(self, func): assert ta.shape == self.shape -one_arg_scalar_funcs = [(w.size, np.size), (w.shape, np.shape), (w.ndim, np.ndim)] +one_arg_scalar_funcs = [(w.size, _np.size), (w.shape, _np.shape), (w.ndim, _np.ndim)] @pytest.mark.parametrize("func, np_func", one_arg_scalar_funcs) @@ -221,7 +221,7 @@ class TestOneArrToScalar: def test_tensor(self, func, np_func): t = torch.Tensor([[1, 2, 3], [4, 5, 6]]) ta = func(t) - tn = np_func(np.asarray(t)) + tn = 
np_func(_np.asarray(t)) assert not isinstance(ta, w.ndarray) assert ta == tn @@ -384,3 +384,31 @@ class TestPythonArgsToArray: def test_simple(self, func, args): a = func(*args) assert isinstance(a, w.ndarray) + + +class TestNormalizations: + """Smoke test generic problems with normalizations.""" + + def test_unknown_args(self): + # Check that unknown args to decorated functions fail + a = w.arange(7) % 2 == 0 + + # unknown positional args + with assert_raises(TypeError): + w.nonzero(a, "kaboom") + + # unknown kwarg + with assert_raises(TypeError): + w.nonzero(a, oops="ouch") + + def test_too_few_args_positional(self): + with assert_raises(TypeError): + w.nonzero() + + def test_unknown_args_with_defaults(self): + # check a function 5 arguments and 4 defaults: this should work + w.eye(3) + + # five arguments, four defaults: this should fail + with assert_raises(TypeError): + w.eye() diff --git a/torch_np/tests/test_ndarray_methods.py b/torch_np/tests/test_ndarray_methods.py index 51608dfd..c31aaf9a 100644 --- a/torch_np/tests/test_ndarray_methods.py +++ b/torch_np/tests/test_ndarray_methods.py @@ -17,7 +17,7 @@ def test_indexing_simple(self): assert isinstance(a[0, 0], np.ndarray) assert isinstance(a[0, :], np.ndarray) - assert a[0, :].base is a + assert a[0, :].get()._base is a.get() def test_setitem(self): a = np.array([[1, 2, 3], [4, 5, 6]]) @@ -33,7 +33,7 @@ def test_reshape_function(self): assert np.all(np.reshape(arr, (2, 6)) == tgt) arr = np.asarray(arr) - assert np.transpose(arr, (1, 0)).base is arr + assert np.transpose(arr, (1, 0)).get()._base is arr.get() def test_reshape_method(self): arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) @@ -43,24 +43,24 @@ def test_reshape_method(self): # reshape(*shape_tuple) assert np.all(arr.reshape(2, 6) == tgt) - assert arr.reshape(2, 6).base is arr # reshape keeps the base + assert arr.reshape(2, 6).get()._base is arr.get() # reshape keeps the base assert arr.shape == arr_shape # arr is intact # XXX: move out to dedicated test(s) - assert arr.reshape(2, 6)._tensor._base is arr._tensor + assert arr.reshape(2, 6).get()._base is arr.get() # reshape(shape_tuple) assert np.all(arr.reshape((2, 6)) == tgt) - assert arr.reshape((2, 6)).base is arr + assert arr.reshape((2, 6)).get()._base is arr.get() assert arr.shape == arr_shape tgt = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] assert np.all(arr.reshape(3, 4) == tgt) - assert arr.reshape(3, 4).base is arr + assert arr.reshape(3, 4).get()._base is arr.get() assert arr.shape == arr_shape assert np.all(arr.reshape((3, 4)) == tgt) - assert arr.reshape((3, 4)).base is arr + assert arr.reshape((3, 4)).get()._base is arr.get() assert arr.shape == arr_shape @@ -82,7 +82,7 @@ def test_transpose_function(self): assert_equal(np.transpose(arr, (1, 0)), tgt) arr = np.asarray(arr) - assert np.transpose(arr, (1, 0)).base is arr + assert np.transpose(arr, (1, 0)).get()._base is arr.get() def test_transpose_method(self): a = np.array([[1, 2], [3, 4]]) @@ -92,7 +92,7 @@ def test_transpose_method(self): assert_raises(ValueError, lambda: a.transpose(0, 0)) assert_raises(ValueError, lambda: a.transpose(0, 1, 2)) - assert a.transpose().base is a + assert a.transpose().get()._base is a.get() class TestRavel: @@ -102,13 +102,13 @@ def test_ravel_function(self): assert_equal(np.ravel(a), tgt) arr = np.asarray(a) - assert np.ravel(arr).base is arr + assert np.ravel(arr).get()._base is arr.get() def test_ravel_method(self): a = np.array([[0, 1], [2, 3]]) assert_equal(a.ravel(), [0, 1, 2, 3]) - assert 
a.ravel().base is a + assert a.ravel().get()._base is a.get() class TestNonzero: @@ -323,7 +323,6 @@ def test_np_vs_ndarray(self, arr_method, np_method): assert_equal(arg_method(out=out1, axis=0), np_method(a, out=out2, axis=0)) assert_equal(out1, out2) - @pytest.mark.xfail(reason="out=... as a positional arg") @pytest.mark.parametrize( "arr_method, np_method", [("argmax", np.argmax), ("argmin", np.argmin)] )
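For readers unfamiliar with the `@normalizer` pattern that this patch applies throughout `_wrapper.py` and `random.py`: the decorator (defined in `torch_np/_normalizations.py`, which is not part of this excerpt) reads the wrapped function's annotations and converts the annotated arguments before the body runs. The sketch below is a simplified, hypothetical stand-in that only handles plain positional/keyword parameters annotated as `ArrayLike`; it is not the project's implementation.

import functools
import inspect

import torch


class ArrayLike:
    """Illustrative marker annotation: convert this argument to a torch.Tensor."""


def normalizer_sketch(func):
    """Simplified take on the @normalizer idea: tensor-ify ArrayLike arguments."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        bound = sig.bind(*args, **kwds)
        bound.apply_defaults()
        for name, param in sig.parameters.items():
            value = bound.arguments.get(name)
            if param.annotation is ArrayLike and value is not None:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped


@normalizer_sketch
def inner_sketch(a: ArrayLike, b: ArrayLike, /):
    # by the time the body runs, a and b are already torch.Tensors
    return torch.inner(a, b)


print(inner_sketch([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))  # tensor(32.)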
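Nearly every rewritten wrapper ends with `_helpers.result_or_out(result, out)`. That helper is pre-existing and not shown in this diff; the tensor-level sketch below only illustrates the contract the callers rely on (copy the computed tensor into `out` when one is supplied and return `out`, otherwise return the fresh result, which the real helper wraps into an ndarray). The exact validation it performs is an assumption here.

from typing import Optional

import torch


def result_or_out_sketch(result: torch.Tensor, out: Optional[torch.Tensor] = None):
    if out is None:
        # the real helper wraps this tensor into a torch_np.ndarray
        return result
    if out.shape != result.shape:
        raise ValueError("Bad size of the out array.")
    # write into the caller-provided buffer and hand it back, so that
    # `ret is out` holds (cf. test_ret_is_out in the test changes above)
    out.copy_(result)
    return out


buf = torch.empty(3, dtype=torch.float64)
ret = result_or_out_sketch(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64), out=buf)
assert ret is buf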
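A usage-level view of what the normalizations buy, assuming the package is importable as `torch_np` and behaves as the patched wrappers intend: plain Python sequences are accepted wherever an argument is annotated `ArrayLike`, and the "tuple of arrays out" helpers return wrapped ndarrays.

import torch_np as tnp

# sequences are normalized to tensors before the torch call
a = tnp.concatenate(([1, 2], [3, 4]))
assert isinstance(a, tnp.ndarray)

# tuple-of-arrays-out pattern (cf. _helpers.tuple_arrays_from in the diff)
x, y = tnp.broadcast_arrays([[1], [2]], [10, 20, 30])
assert x.shape == y.shape == (2, 3)

# single-argument where returns a tuple of index arrays, as in numpy
(idx,) = tnp.where(tnp.asarray([0, 1, 0, 2]))
assert isinstance(idx, tnp.ndarray)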
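The test updates replace assertions of the form `arr2.base is arr` with `arr2.get()._base is arr.get()`, i.e. the base/view relationship is now checked on the underlying torch tensors. `_base` is an internal `torch.Tensor` attribute; a minimal check of the behavior those assertions rely on:

import torch

t = torch.arange(12)
v = t.reshape(3, 4)      # reshape of a contiguous tensor is a view, no copy
assert v._base is t      # torch records the tensor a view was derived from

c = t.flip(0)            # flip returns a copy, not a view
assert c._base is None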