From 02145bb581c7e18da0f190f61d6abd0f7d16e2ee Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 26 Apr 2023 11:37:22 +0000 Subject: [PATCH 1/2] Fix and minors Don't use out= explicitly, improved a bit the implementation of `average`. Minor improvements here and there e.g. - prefer flatten over ravel as it's more PyTorch-y - Prefer `.double()` or `.long()` over `to(float)` (I didn't even know that worked...) for the same reason - Don't call .item() if we can avoid it (added to the list of differences) - remove _mappings - remove the need of semi_private methods - Fixed ndarray.fill, as the normalizer was not working because of the future annotations --- torch_np/_funcs.py | 9 +-- torch_np/_funcs_impl.py | 122 +++++++++------------------------ torch_np/_getlimits.py | 6 +- torch_np/_mapping.py | 130 ----------------------------------- torch_np/_ndarray.py | 76 +++++++++++++++------ torch_np/_normalizations.py | 8 +-- torch_np/_reductions.py | 133 ++++++++++-------------------------- torch_np/_util.py | 9 ++- 8 files changed, 133 insertions(+), 360 deletions(-) delete mode 100644 torch_np/_mapping.py diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 5fab34ef..fdb85503 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -19,15 +19,8 @@ if inspect.isfunction(getattr(_funcs_impl, x)) and not x.startswith("_") ] -# these implement ndarray methods but need not be public functions -semi_private = [ - "_flatten", - "_ndarray_resize", -] - - # decorate implementer functions with argument normalizers and export to the top namespace -for name in __all__ + semi_private: +for name in __all__: func = getattr(_funcs_impl, name) if name in ["percentile", "quantile", "median"]: decorated = normalizer(func, promote_scalar_result=True) diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py index e1171391..14c506b1 100644 --- a/torch_np/_funcs_impl.py +++ b/torch_np/_funcs_impl.py @@ -8,7 +8,6 @@ from __future__ import annotations import builtins -import math import operator from typing import Optional, Sequence @@ -100,7 +99,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): # pure torch implementation, used below and in cov/corrcoef below - tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) + tensors, axis = _util.axis_none_flatten(*tensors, axis=axis) tensors = _concat_cast_helper(tensors, out, dtype, casting) return torch.cat(tensors, axis) @@ -881,7 +880,7 @@ def take( out: Optional[OutArray] = None, mode: NotImplementedType = "raise", ): - (a,), axis = _util.axis_none_ravel(a, axis=axis) + (a,), axis = _util.axis_none_flatten(a, axis=axis) axis = _util.normalize_axis_index(axis, a.ndim) idx = (slice(None),) * axis + (indices, ...) 
result = a[idx] @@ -889,13 +888,13 @@ def take( def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): - (arr,), axis = _util.axis_none_ravel(arr, axis=axis) + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) axis = _util.normalize_axis_index(axis, arr.ndim) return torch.take_along_dim(arr, indices, axis) def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): - (arr,), axis = _util.axis_none_ravel(arr, axis=axis) + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) axis = _util.normalize_axis_index(axis, arr.ndim) indices, values = torch.broadcast_tensors(indices, values) @@ -917,9 +916,7 @@ def unique( *, equal_nan: NotImplementedType = True, ): - if axis is None: - ar = ar.ravel() - axis = 0 + (ar,), axis = _util.axis_none_flatten(ar, axis=axis) axis = _util.normalize_axis_index(axis, ar.ndim) is_half = ar.dtype == torch.float16 @@ -948,7 +945,7 @@ def argwhere(a: ArrayLike): def flatnonzero(a: ArrayLike): - return torch.ravel(a).nonzero(as_tuple=True)[0] + return torch.flatten(a).nonzero(as_tuple=True)[0] def clip( @@ -980,7 +977,7 @@ def resize(a: ArrayLike, new_shape=None): if isinstance(new_shape, int): new_shape = (new_shape,) - a = ravel(a) + a = a.flatten() new_size = 1 for dim_length in new_shape: @@ -998,38 +995,6 @@ def resize(a: ArrayLike, new_shape=None): return reshape(a, new_shape) -def _ndarray_resize(a: ArrayLike, new_shape, refcheck=False): - # implementation of ndarray.resize. - # NB: differs from np.resize: fills with zeros instead of making repeated copies of input. - if refcheck: - raise NotImplementedError( - f"resize(..., refcheck={refcheck} is not implemented." - ) - - if new_shape in [(), (None,)]: - return a - - # support both x.resize((2, 2)) and x.resize(2, 2) - if len(new_shape) == 1: - new_shape = new_shape[0] - if isinstance(new_shape, int): - new_shape = (new_shape,) - - a = ravel(a) - - if builtins.any(x < 0 for x in new_shape): - raise ValueError("all elements of `new_shape` must be non-negative") - - new_numel = math.prod(new_shape) - if new_numel < a.numel(): - # shrink - return a[:new_numel].reshape(new_shape) - else: - b = torch.zeros(new_numel) - b[: a.numel()] = a - return b.reshape(new_shape) - - # ### diag et al ### @@ -1132,13 +1097,13 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): def vdot(a: ArrayLike, b: ArrayLike, /): - # 1. torch only accepts 1D arrays, numpy ravels + # 1. torch only accepts 1D arrays, numpy flattens # 2. torch requires matching dtype, while numpy casts (?) 
t_a, t_b = torch.atleast_1d(a, b) if t_a.ndim > 1: - t_a = t_a.ravel() + t_a = t_a.flatten() if t_b.ndim > 1: - t_b = t_b.ravel() + t_b = t_b.flatten() dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype)) is_half = dtype == torch.float16 @@ -1212,7 +1177,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): def _sort_helper(tensor, axis, kind, order): - (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis) + (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) axis = _util.normalize_axis_index(axis, tensor.ndim) stable = kind == "stable" @@ -1328,14 +1293,6 @@ def transpose(a: ArrayLike, axes=None): def ravel(a: ArrayLike, order: NotImplementedType = "C"): - return torch.ravel(a) - - -# leading underscore since arr.flatten exists but np.flatten does not - - -def _flatten(a: ArrayLike, order: NotImplementedType = "C"): - # may return a copy return torch.flatten(a) @@ -1647,7 +1604,7 @@ def diff( def angle(z: ArrayLike, deg=False): result = torch.angle(z) if deg: - result = result * 180 / torch.pi + result = result * (180 / torch.pi) return result @@ -1658,26 +1615,14 @@ def sinc(x: ArrayLike): # ### Type/shape etc queries ### -def real(a: ArrayLike): - return torch.real(a) - - -def imag(a: ArrayLike): - if a.is_complex(): - result = a.imag - else: - result = torch.zeros_like(a) - return result - - def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None): if a.is_floating_point(): result = torch.round(a, decimals=decimals) elif a.is_complex(): # RuntimeError: "round_cpu" not implemented for 'ComplexFloat' - result = ( - torch.round(a.real, decimals=decimals) - + torch.round(a.imag, decimals=decimals) * 1j + result = torch.complex( + torch.round(a.real, decimals=decimals), + torch.round(a.imag, decimals=decimals), ) else: # RuntimeError: "round_cpu" not implemented for 'int' @@ -1690,7 +1635,6 @@ def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None): def real_if_close(a: ArrayLike, tol=100): - # XXX: copies vs views; numpy seems to return a copy? 
if not torch.is_complex(a): return a if tol > 1: @@ -1703,27 +1647,30 @@ def real_if_close(a: ArrayLike, tol=100): return a.real if mask.all() else a +def real(a: ArrayLike): + return torch.real(a) + + +def imag(a: ArrayLike): + if a.is_complex(): + return a.imag + return torch.zeros_like(a) + + def iscomplex(x: ArrayLike): if torch.is_complex(x): return x.imag != 0 - result = torch.zeros_like(x, dtype=torch.bool) - if result.ndim == 0: - result = result.item() - return result + return torch.zeros_like(x, dtype=torch.bool) def isreal(x: ArrayLike): if torch.is_complex(x): return x.imag == 0 - result = torch.ones_like(x, dtype=torch.bool) - if result.ndim == 0: - result = result.item() - return result + return torch.ones_like(x, dtype=torch.bool) def iscomplexobj(x: ArrayLike): - result = torch.is_complex(x) - return result + return torch.is_complex(x) def isrealobj(x: ArrayLike): @@ -1731,11 +1678,11 @@ def isrealobj(x: ArrayLike): def isneginf(x: ArrayLike, out: Optional[OutArray] = None): - return torch.isneginf(x, out=out) + return torch.isneginf(x) def isposinf(x: ArrayLike, out: Optional[OutArray] = None): - return torch.isposinf(x, out=out) + return torch.isposinf(x) def i0(x: ArrayLike): @@ -1743,7 +1690,6 @@ def i0(x: ArrayLike): def isscalar(a): - # XXX: this is a stub try: t = normalize_array_like(a) return t.numel() == 1 @@ -1798,8 +1744,6 @@ def bartlett(M): def common_type(*tensors: ArrayLike): - import builtins - is_complex = False precision = 0 for a in tensors: @@ -1836,7 +1780,7 @@ def histogram( is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex) is_w_int = weights is None or not weights.dtype.is_floating_point if is_a_int: - a = a.to(float) + a = a.double() if weights is not None: weights = _util.cast_if_needed(weights, a.dtype) @@ -1856,8 +1800,8 @@ def histogram( ) if not density and is_w_int: - h = h.to(int) + h = h.long() if is_a_int: - b = b.to(int) + b = b.long() return h, b diff --git a/torch_np/_getlimits.py b/torch_np/_getlimits.py index 229c8963..4f28306a 100644 --- a/torch_np/_getlimits.py +++ b/torch_np/_getlimits.py @@ -1,3 +1,5 @@ +import contextlib + import torch from . 
import _dtypes @@ -13,10 +15,6 @@ def iinfo(dtyp): return torch.iinfo(torch_dtype) -import contextlib - - -# FIXME: this is only a stub @contextlib.contextmanager def errstate(*args, **kwds): yield diff --git a/torch_np/_mapping.py b/torch_np/_mapping.py deleted file mode 100644 index a61db918..00000000 --- a/torch_np/_mapping.py +++ /dev/null @@ -1,130 +0,0 @@ -import numpy as np - -mapping = { - np.abs: abs, - np.absolute: absolute, - np.add: add, - np.angle: angle, - np.arccos: arccos, - np.arccosh: arccosh, - np.arcsin: arcsin, - np.arcsinh: arcsinh, - np.arctan: arctan, - np.arctan2: arctan2, - np.arctanh: arctanh, - np.argmax: argmax, - np.array: array, - np.asarray: asarray, - np.atleast_1d: atleast_1d, - np.atleast_2d: atleast_2d, - np.atleast_3d: atleast_3d, - np.bincount: bincount, - np.bitwise_and: bitwise_and, - np.bitwise_or: bitwise_or, - np.bitwise_xor: bitwise_xor, - np.broadcast_arrays: broadcast_arrays, - np.broadcast_shapes: broadcast_shapes, - np.broadcast_to: broadcast_to, - np.cbrt: cbrt, - np.ceil: ceil, - np.concatenate: concatenate, - np.conjugate: conjugate, - np.copy: copy, - np.copysign: copysign, - np.corrcoef: corrcoef, - np.cos: cos, - np.cosh: cosh, - np.deg2rad: deg2rad, - np.degrees: degrees, - np.divide: divide, - np.empty: empty, - np.empty_like: empty_like, - np.equal: equal, - np.exp: exp, - np.exp2: exp2, - np.expm1: expm1, - np.eye: eye, - np.fabs: fabs, - np.float_power: float_power, - np.floor: floor, - np.floor_divide: floor_divide, - np.fmax: fmax, - np.fmin: fmin, - np.fmod: fmod, - np.full: full, - np.full_like: full_like, - np.gcd: gcd, - np.greater: greater, - np.greater_equal: greater_equal, - np.heaviside: heaviside, - np.hypot: hypot, - np.i0: i0, - np.identity: identity, - np.imag: imag, - np.iscomplex: iscomplex, - np.iscomplexobj: iscomplexobj, - np.isfinite: isfinite, - np.isinf: isinf, - np.isnan: isnan, - np.isneginf: isneginf, - np.isposinf: isposinf, - np.isreal: isreal, - np.isrealobj: isrealobj, - np.lcm: lcm, - np.ldexp: ldexp, - np.left_shift: left_shift, - np.less: less, - np.less_equal: less_equal, - np.linspace: linspace, - np.log: log, - np.log10: log10, - np.log1p: log1p, - np.log2: log2, - np.logaddexp: logaddexp, - np.logaddexp2: logaddexp2, - np.logical_and: logical_and, - np.logical_not: logical_not, - np.logical_or: logical_or, - np.logical_xor: logical_xor, - np.matmul: matmul, - np.maximum: maximum, - np.minimum: minimum, - np.multiply: multiply, - np.ndim: ndim, - np.negative: negative, - np.nextafter: nextafter, - np.not_equal: not_equal, - np.np: np, - np.ones: ones, - np.ones_like: ones_like, - np.positive: positive, - np.power: power, - np.prod: prod, - np.rad2deg: rad2deg, - np.radians: radians, - np.ravel_multi_index: ravel_multi_index, - np.real: real, - np.real_if_close: real_if_close, - np.reciprocal: reciprocal, - np.remainder: remainder, - np.reshape: reshape, - np.right_shift: right_shift, - np.rint: rint, - np.shape: shape, - np.sign: sign, - np.signbit: signbit, - np.sin: sin, - np.sinh: sinh, - np.size: size, - np.sqrt: sqrt, - np.square: square, - np.squeeze: squeeze, - np.subtract: subtract, - np.tan: tan, - np.tanh: tanh, - np.torch: torch, - np.trunc: trunc, - np.unravel_index: unravel_index, - np.zeros: zeros, - np.zeros_like: zeros_like, -} diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 3b67afa6..2a363b6b 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -1,9 +1,13 @@ +from __future__ import annotations + +import builtins +import math import operator import torch 
from . import _dtypes, _dtypes_impl, _funcs, _funcs_impl, _helpers, _ufuncs, _util -from ._normalizations import ArrayLike, normalizer +from ._normalizations import ArrayLike, NotImplementedType, normalizer newaxis = None @@ -72,7 +76,6 @@ def f(*args, **kwargs): # If name_func == None, it means that name_method == name_func methods = { "clip": None, - "flatten": "_flatten", "nonzero": None, "repeat": None, "round": None, @@ -252,11 +255,47 @@ def astype(self, dtype): t = self.tensor.to(torch_dtype) return ndarray(t) - def copy(self, order="C"): - if order != "C": - raise NotImplementedError - tensor = self.tensor.clone() - return ndarray(tensor) + @normalizer + def copy(self: ArrayLike, order: NotImplementedType = "C"): + return self.clone() + + @normalizer + def flatten(self: ArrayLike, order: NotImplementedType = "C"): + return torch.flatten(self) + + def resize(self, *new_shape, refcheck=False): + a = self.tensor + # TODO(Lezcano) This is not done in-place + # implementation of ndarray.resize. + # NB: differs from np.resize: fills with zeros instead of making repeated copies of input. + if refcheck: + raise NotImplementedError( + f"resize(..., refcheck={refcheck} is not implemented." + ) + + if new_shape in [(), (None,)]: + return + + # support both x.resize((2, 2)) and x.resize(2, 2) + if len(new_shape) == 1: + new_shape = new_shape[0] + if isinstance(new_shape, int): + new_shape = (new_shape,) + + a = a.flatten() + + if builtins.any(x < 0 for x in new_shape): + raise ValueError("all elements of `new_shape` must be non-negative") + + new_numel = math.prod(new_shape) + if new_numel < a.numel(): + # shrink + ret = a[:new_numel].reshape(new_shape) + else: + b = torch.zeros(new_numel) + b[: a.numel()] = a + ret = b.reshape(new_shape) + self.tensor = ret def view(self, dtype): torch_dtype = _dtypes.dtype(dtype).torch_dtype @@ -327,11 +366,10 @@ def __complex__(self): def __int__(self): return int(self.tensor) - # XXX : are single-element ndarrays scalars? - # in numpy, only array scalars have the `is_integer` method def is_integer(self): try: - result = int(self.tensor) == self.tensor + v = self.tensor.item() + result = int(v) == v except Exception: result = False return result @@ -349,14 +387,6 @@ def reshape(self, *shape, order="C"): # arr.reshape(shape) and arr.reshape(*shape) return _funcs.reshape(self, shape, order=order) - def resize(self, *new_shape, refcheck=False): - # ndarray.resize works in-place (may cause a reallocation though) - self.tensor = _funcs_impl._ndarray_resize( - self.tensor, new_shape, refcheck=refcheck - ) - - ### sorting ### - def sort(self, axis=-1, kind=None, order=None): # ndarray.sort works in-place _funcs.copyto(self, _funcs.sort(self, axis, kind, order)) @@ -408,9 +438,13 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N raise NotImplementedError # a happy path - if isinstance(obj, ndarray): - if copy is False and dtype is None and ndmin <= obj.ndim: - return obj + if ( + isinstance(obj, ndarray) + and copy is False + and dtype is None + and ndmin <= obj.ndim + ): + return obj # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists if isinstance(obj, (list, tuple)): diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py index ccfb6d0d..45735881 100644 --- a/torch_np/_normalizations.py +++ b/torch_np/_normalizations.py @@ -3,12 +3,13 @@ from __future__ import annotations import functools +import inspect import operator import typing import torch -from . import _helpers +from . 
import _dtypes ArrayLike = typing.TypeVar("ArrayLike") DTypeLike = typing.TypeVar("DTypeLike") @@ -34,11 +35,6 @@ NotImplementedType = typing.TypeVar("NotImplementedType") -import inspect - -from . import _dtypes - - def normalize_array_like(x, parm=None): from ._ndarray import asarray diff --git a/torch_np/_reductions.py b/torch_np/_reductions.py index b0307200..b5a7e053 100644 --- a/torch_np/_reductions.py +++ b/torch_np/_reductions.py @@ -5,29 +5,19 @@ """ import functools -import typing import torch from . import _dtypes_impl, _util -############# XXX -### From _util.axis_expand_func - def deco_axis_expand(func): """Generically handle axis arguments in reductions.""" @functools.wraps(func) - def wrapped(tensor, axis, *args, **kwds): + def wrapped(tensor, axis=None, *args, **kwds): if axis is not None: - if not isinstance(axis, (list, tuple)): - if not isinstance(axis, typing.SupportsIndex): - raise TypeError( - f"{type(axis)=}, but should be a list/tuple or support operator.index()" - ) - axis = (axis,) axis = _util.normalize_axis_tuple(axis, tensor.ndim) if axis == (): @@ -39,7 +29,7 @@ def wrapped(tensor, axis, *args, **kwds): tensor = tensor.reshape(newshape) axis = (0,) - result = func(tensor, axis=axis, *args, **kwds) + result = func(tensor, axis, *args, **kwds) return result return wrapped @@ -48,7 +38,7 @@ def wrapped(tensor, axis, *args, **kwds): def emulate_keepdims(func): @functools.wraps(func) def wrapped(tensor, axis=None, keepdims=None, *args, **kwds): - result = func(tensor, axis=axis, *args, **kwds) + result = func(tensor, axis, *args, **kwds) if keepdims: result = _util.apply_keepdims(result, axis, tensor.ndim) return result @@ -56,26 +46,6 @@ def wrapped(tensor, axis=None, keepdims=None, *args, **kwds): return wrapped -def deco_axis_ravel(func): - """Generically handle 'axis=None ravels' behavior.""" - - @functools.wraps(func) - def wrapped(tensor, axis, *args, **kwds): - if axis is not None: - axis = _util.normalize_axis_index(axis, tensor.ndim) - - tensors, axis = _util.axis_none_ravel(tensor, axis=axis) # XXX: inline - tensor = tensors[0] - - result = func(tensor, axis=axis, *args, **kwds) - return result - - return wrapped - - -##################################3 - - def _atleast_float(dtype, other_dtype): """Return a dtype that is real or complex floating-point. 
@@ -93,12 +63,7 @@ def _atleast_float(dtype, other_dtype): @emulate_keepdims @deco_axis_expand def count_nonzero(a, axis=None): - # XXX: this all should probably be generalized to a sum(a != 0, dtype=bool) - try: - return a.count_nonzero(axis) - except RuntimeError: - raise ValueError - return tensor + return a.count_nonzero(axis) @emulate_keepdims @@ -110,8 +75,7 @@ def argmax(tensor, axis=None): # RuntimeError: "argmax_cpu" not implemented for 'Bool' tensor = tensor.to(torch.uint8) - tensor = torch.argmax(tensor, axis) - return tensor + return torch.argmax(tensor, axis) @emulate_keepdims @@ -123,32 +87,23 @@ def argmin(tensor, axis=None): # RuntimeError: "argmin_cpu" not implemented for 'Bool' tensor = tensor.to(torch.uint8) - tensor = torch.argmin(tensor, axis) - return tensor + return torch.argmin(tensor, axis) @emulate_keepdims @deco_axis_expand def any(tensor, axis=None, *, where=None): axis = _util.allow_only_single_axis(axis) - - if axis is None: - result = tensor.any() - else: - result = tensor.any(axis) - return result + axis_kw = {} if axis is None else {"dim": axis} + return torch.any(tensor, **axis_kw) @emulate_keepdims @deco_axis_expand def all(tensor, axis=None, *, where=None): axis = _util.allow_only_single_axis(axis) - - if axis is None: - result = tensor.all() - else: - result = tensor.all(axis) - return result + axis_kw = {} if axis is None else {"dim": axis} + return torch.all(tensor, **axis_kw) @emulate_keepdims @@ -227,9 +182,7 @@ def mean(tensor, axis=None, dtype=None, *, where=None): def std(tensor, axis=None, dtype=None, ddof=0, *, where=None): dtype = _atleast_float(dtype, tensor.dtype) tensor = _util.cast_if_needed(tensor, dtype) - result = tensor.std(dim=axis, correction=ddof) - - return result + return tensor.std(dim=axis, correction=ddof) @emulate_keepdims @@ -237,38 +190,36 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=None): def var(tensor, axis=None, dtype=None, ddof=0, *, where=None): dtype = _atleast_float(dtype, tensor.dtype) tensor = _util.cast_if_needed(tensor, dtype) - result = tensor.var(dim=axis, correction=ddof) - - return result + return tensor.var(dim=axis, correction=ddof) # cumsum / cumprod are almost reductions: # 1. no keepdims -# 2. axis=None ravels (cf concatenate) +# 2. 
axis=None flattens -@deco_axis_ravel def cumprod(tensor, axis, dtype=None): if dtype == torch.bool: dtype = _dtypes_impl.default_dtypes.int_dtype if dtype is None: dtype = tensor.dtype - result = tensor.cumprod(axis=axis, dtype=dtype) + (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) + axis = _util.normalize_axis_index(axis, tensor.ndim) - return result + return tensor.cumprod(axis=axis, dtype=dtype) -@deco_axis_ravel def cumsum(tensor, axis, dtype=None): if dtype == torch.bool: dtype = _dtypes_impl.default_dtypes.int_dtype if dtype is None: dtype = tensor.dtype - result = tensor.cumsum(axis=axis, dtype=dtype) + (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) + axis = _util.normalize_axis_index(axis, tensor.ndim) - return result + return tensor.cumsum(axis=axis, dtype=dtype) def average(a, axis, weights, returned=False, keepdims=False): @@ -277,9 +228,8 @@ def average(a, axis, weights, returned=False, keepdims=False): else: result, wsum = average_weights(a, axis, weights, keepdims=keepdims) - if returned: - if wsum.shape != result.shape: - wsum = torch.broadcast_to(wsum, result.shape).clone() + if returned and wsum.shape != result.shape: + wsum = torch.broadcast_to(wsum, result.shape).clone() return result, wsum @@ -290,26 +240,11 @@ def average_noweights(a, axis, keepdims=False): def average_weights(a, axis, w, keepdims=False): - - # dtype - # FIXME: 1. use result_type - # 2. actually implement multiply w/dtype if not a.dtype.is_floating_point: - result_dtype = torch.float64 - a = a.to(result_dtype) + a = a.double() result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype]) - a = _util.cast_if_needed(a, result_dtype) - w = _util.cast_if_needed(w, result_dtype) - - # axis=None ravels, so store the originals to reuse with keepdims=True below - ax, ndim = axis, a.ndim - - # axis - if axis is None: - (a, w), axis = _util.axis_none_ravel(a, w, axis=axis) - # axis & weights if a.shape != w.shape: if axis is None: @@ -326,13 +261,13 @@ def average_weights(a, axis, w, keepdims=False): w = w.swapaxes(-1, axis) # do the work - numerator = torch.mul(a, w).sum(axis) - denominator = w.sum(axis) + numerator = torch.mul(a, w).sum(axis, dtype=result_dtype) + denominator = w.sum(axis, dtype=result_dtype) result = numerator / denominator # keepdims if keepdims: - result = _util.apply_keepdims(result, ax, ndim) + result = _util.apply_keepdims(result, axis, a.ndim) return result, denominator @@ -361,18 +296,22 @@ def quantile( if a.dtype == torch.float16: a = a.to(torch.float32) - # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all. - # axis - if axis is not None: + # axis=None flattens, so store the originals to reuse with keepdims=True below + ax, ndim = axis, a.ndim + if axis is None: + a = a.flatten() + q = q.flatten() + axis = (0,) + else: axis = _util.normalize_axis_tuple(axis, a.ndim) + + # FIXME(Mario) Doesn't np.quantile accept a tuple? + # torch.quantile does accept a number. If we don't want to implement the tuple behaviour + # (it's deffo low prio) change `normalize_axis_tuple` into a normalize_axis index above. axis = _util.allow_only_single_axis(axis) q = _util.cast_if_needed(q, a.dtype) - # axis=None ravels, so store the originals to reuse with keepdims=True below - ax, ndim = axis, a.ndim - (a, q), axis = _util.axis_none_ravel(a, q, axis=axis) - result = torch.quantile(a, q, axis=axis, interpolation=method) # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...) 
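The cumsum/cumprod hunks above lean on the "axis=None flattens" convention provided by `_util.axis_none_flatten`. A minimal standalone sketch of that convention (the helper and wrapper names below are hypothetical, and the real code additionally normalizes negative axes and remaps bool dtypes):

```python
import torch


def _axis_none_flatten(tensor, axis):
    # With axis=None the input is flattened and the op then runs over dim 0.
    if axis is None:
        return tensor.flatten(), 0
    return tensor, axis


def cumsum_like_numpy(tensor, axis=None, dtype=None):
    # No keepdims handling, and axis=None flattens (cf. np.cumsum).
    if dtype is None:
        dtype = tensor.dtype
    tensor, axis = _axis_none_flatten(tensor, axis)
    return tensor.cumsum(dim=axis, dtype=dtype)


t = torch.arange(6).reshape(2, 3)
print(cumsum_like_numpy(t))          # tensor([ 0,  1,  3,  6, 10, 15])
print(cumsum_like_numpy(t, axis=1))  # row-wise cumulative sums
```
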
diff --git a/torch_np/_util.py b/torch_np/_util.py index 8d4c0be2..8c56cc90 100644 --- a/torch_np/_util.py +++ b/torch_np/_util.py @@ -119,18 +119,17 @@ def apply_keepdims(tensor, axis, ndim): if axis is None: # tensor was a scalar shape = (1,) * ndim - tensor = tensor.expand(shape).contiguous() # avoid CUDA synchronization + tensor = tensor.expand(shape).contiguous() else: shape = expand_shape(tensor.shape, axis) tensor = tensor.reshape(shape) return tensor -def axis_none_ravel(*tensors, axis=None): - """Ravel the arrays if axis is none.""" - # XXX: is only used at `concatenate` and cumsum/cumprod. Inline unless reused more widely +def axis_none_flatten(*tensors, axis=None): + """Flatten the arrays if axis is None else normalize the axis.""" if axis is None: - tensors = tuple(ar.ravel() for ar in tensors) + tensors = tuple(ar.flatten() for ar in tensors) return tensors, 0 else: return tensors, axis From c3a9b7eef2613fc57b8ecd6cb40660764ca3c4b4 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 27 Apr 2023 08:37:50 +0000 Subject: [PATCH 2/2] revert comment --- torch_np/_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_np/_util.py b/torch_np/_util.py index 8c56cc90..e120898c 100644 --- a/torch_np/_util.py +++ b/torch_np/_util.py @@ -127,7 +127,7 @@ def apply_keepdims(tensor, axis, ndim): def axis_none_flatten(*tensors, axis=None): - """Flatten the arrays if axis is None else normalize the axis.""" + """Flatten the arrays if axis is None.""" if axis is None: tensors = tuple(ar.flatten() for ar in tensors) return tensors, 0
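
The `_util.py` hunks above touch `apply_keepdims`, which backs the `emulate_keepdims` decorator in `_reductions.py`: many torch reductions expose a native `keepdim` flag, but some (e.g. `count_nonzero`) do not, so the size-1 dimensions are reinserted after the reduction. A simplified sketch with a hypothetical `sum_keepdims` wrapper, restricted to a single non-negative axis:

```python
import torch


def sum_keepdims(tensor, axis=None, keepdims=False):
    # Reduce first, then reinsert size-1 dims so the result still
    # broadcasts against the input (the same idea as apply_keepdims).
    if axis is None:
        result = tensor.sum()
        shape = (1,) * tensor.ndim
    else:
        result = tensor.sum(dim=axis)
        shape = list(tensor.shape)
        shape[axis] = 1
    return result.reshape(shape) if keepdims else result


t = torch.ones(2, 3)
print(sum_keepdims(t, axis=1, keepdims=True).shape)  # torch.Size([2, 1])
print(sum_keepdims(t, keepdims=True).shape)          # torch.Size([1, 1])
```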