From c44900f8a758da99d40a4cb85cfa30ea78fff44f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 11 Feb 2023 18:10:16 +0300 Subject: [PATCH 01/20] MAINT: split cov, corrcoef, concat into ndarray/torch parts --- torch_np/_decorators.py | 2 +- torch_np/_detail/implementations.py | 77 +++++++++++++++++++++++++++++ torch_np/_wrapper.py | 70 ++++---------------------- 3 files changed, 88 insertions(+), 61 deletions(-) diff --git a/torch_np/_decorators.py b/torch_np/_decorators.py index 4679b92f..37c98bf6 100644 --- a/torch_np/_decorators.py +++ b/torch_np/_decorators.py @@ -15,7 +15,7 @@ def wrapped(*args, dtype=None, **kwds): torch_dtype = None if dtype is not None: dtype = _dtypes.dtype(dtype) - torch_dtype = dtype._scalar_type.torch_dtype + torch_dtype = dtype.torch_dtype return func(*args, dtype=torch_dtype, **kwds) return wrapped diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index d9c59bfb..76b585c0 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -95,6 +95,82 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None): return result +# #### concatenate and relatives + +def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): + # np.concatenate ravels if axis=None + tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) + + # figure out the type of the inputs and outputs + if out is None and dtype is None: + out_dtype = None + else: + out_dtype = out.dtype.torch_dtype if dtype is None else dtype + + # cast input arrays if necessary; do not broadcast them agains `out` + tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting) + + try: + result = torch.cat(tensors, axis) + except (IndexError, RuntimeError): + raise _util.AxisError + + return result + + +# #### cov & corrcoef + +def corrcoef(xy_tensor, rowvar=True, *, dtype=None): + if rowvar is False: + # xy_tensor is at least 2D, so using .T is safe + xy_tensor = x_tensor.T + + is_half = dtype == torch.float16 + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + if dtype is not None: + xy_tensor = xy_tensor.to(dtype) + + result = torch.corrcoef(xy_tensor) + + if is_half: + result = result.to(torch.float16) + + return result + + +def cov( + m_tensor, + bias=False, + ddof=None, + fweights_tensor=None, + aweights_tensor=None, + *, + dtype=None, +): + if ddof is None: + ddof = 1 if bias == 0 else 0 + + is_half = dtype == torch.float16 + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + if dtype is not None: + m_tensor = m_tensor.to(dtype) + + result = torch.cov( + m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor + ) + + if is_half: + result = result.to(torch.float16) + + return result + + def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"): # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L4892-L5047 ndim = len(xi_tensors) @@ -118,3 +194,4 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"): output = [x.clone() for x in output] return output + diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index e4521b01..d27fa320 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -434,23 +434,7 @@ def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None): x = concatenate((x, y), axis=0) x_tensor = asarray(x).get() - - if rowvar is False: - x_tensor = x_tensor.T - - is_half = dtype == torch.float16 - if is_half: - # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" - dtype = torch.float32 - - if dtype is not None: - x_tensor = x_tensor.to(dtype) - - result = torch.corrcoef(x_tensor) - - if is_half: - result = result.to(torch.float16) - + result = _impl.corrcoef(x_tensor, rowvar, dtype=dtype) return asarray(result) @@ -478,45 +462,25 @@ def cov( m = concatenate((m, y), axis=0) - if ddof is None: - if bias == 0: - ddof = 1 - else: - ddof = 0 +# if ddof is None: +# if bias == 0: +# ddof = 1 +# else: +# ddof = 0 m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none( m, fweights, aweights ) - - # work with tensors from now on - is_half = dtype == torch.float16 - if is_half: - # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" - dtype = torch.float32 - - if dtype is not None: - m_tensor = m_tensor.to(dtype) - - result = torch.cov( - m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor - ) - - if is_half: - result = result.to(torch.float16) - + result = _impl.cov(m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype) return asarray(result) +@_decorators.dtype_to_torch def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): if ar_tuple == (): # XXX: RuntimeError in torch, ValueError in numpy raise ValueError("need at least one array to concatenate") - tensors = _helpers.to_tensors(*ar_tuple) - - # np.concatenate ravels if axis=None - tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) - if out is not None: if not isinstance(out, ndarray): raise ValueError("'out' must be an array") @@ -527,22 +491,8 @@ def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): "concatenate() only takes `out` or `dtype` as an " "argument, but both were provided." ) - - # figure out the type of the inputs and outputs - if out is None and dtype is None: - out_dtype = None - else: - out_dtype = out.dtype if dtype is None else _dtypes.dtype(dtype) - out_dtype = out_dtype.type.torch_dtype - - # cast input arrays if necessary; do not broadcast them agains `out` - tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting) - - try: - result = torch.cat(tensors, axis) - except (IndexError, RuntimeError): - raise _util.AxisError - + tensors = _helpers.to_tensors(*ar_tuple) + result = _impl.concatenate(tensors, axis, out, dtype, casting) return _helpers.result_or_out(result, out) From ebb623ed05f31ca648d328536721ebeaa53189bf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 13 Feb 2023 20:51:10 +0300 Subject: [PATCH 02/20] MAINT: trivially simplify concat --- torch_np/_detail/implementations.py | 11 +++++------ torch_np/_wrapper.py | 10 +++------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 76b585c0..194b2fd3 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -97,14 +97,13 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None): # #### concatenate and relatives + def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): # np.concatenate ravels if axis=None tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) - # figure out the type of the inputs and outputs - if out is None and dtype is None: - out_dtype = None - else: + if out is not None or dtype is not None: + # figure out the type of the inputs and outputs out_dtype = out.dtype.torch_dtype if dtype is None else dtype # cast input arrays if necessary; do not broadcast them agains `out` @@ -120,7 +119,8 @@ def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): # #### cov & corrcoef -def corrcoef(xy_tensor, rowvar=True, *, dtype=None): + +def corrcoef(xy_tensor, rowvar=True, *, dtype=None): if rowvar is False: # xy_tensor is at least 2D, so using .T is safe xy_tensor = x_tensor.T @@ -194,4 +194,3 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"): output = [x.clone() for x in output] return output - diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index d27fa320..6ffe448b 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -462,16 +462,12 @@ def cov( m = concatenate((m, y), axis=0) -# if ddof is None: -# if bias == 0: -# ddof = 1 -# else: -# ddof = 0 - m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none( m, fweights, aweights ) - result = _impl.cov(m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype) + result = _impl.cov( + m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype + ) return asarray(result) From 2bce88bb58479e1cef80afbf980d23b5728b4096 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 13 Feb 2023 22:07:25 +0300 Subject: [PATCH 03/20] MAINT: split clip --- torch_np/_detail/implementations.py | 14 ++++++++++++++ torch_np/_ndarray.py | 21 +++------------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 194b2fd3..bda8bbb1 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -72,6 +72,20 @@ def split_helper_list(tensor, indices_or_sections, axis, strict=False): return torch.split(tensor, lst, axis) +def clip(tensor, t_min, t_max): + if t_min is not None: + t_min = torch.broadcast_to(t_min, tensor.shape) + + if t_max is not None: + t_max = torch.broadcast_to(t_max, tensor.shape) + + if t_min is None and t_max is None: + raise ValueError("One of max or min must be given") + + result = tensor.clamp(t_min, t_max) + return result + + def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None): axis = _util.normalize_axis_index(axis, a_tensor.ndim) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index befa3a9e..7b95f15a 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -12,6 +12,7 @@ emulate_out_arg, ) from ._detail import _dtypes_impl, _flips, _reductions, _util +from ._detail import implementations as _impl newaxis = None @@ -367,24 +368,8 @@ def nonzero(self): return tuple(asarray(_) for _ in tensor.nonzero(as_tuple=True)) def clip(self, min, max, out=None): - tensor = self._tensor - a_min, a_max = min, max - - t_min = None - if a_min is not None: - t_min = asarray(a_min).get() - t_min = torch.broadcast_to(t_min, tensor.shape) - - t_max = None - if a_max is not None: - t_max = asarray(a_max).get() - t_max = torch.broadcast_to(t_max, tensor.shape) - - if t_min is None and t_max is None: - raise ValueError("One of max or min must be given") - - result = tensor.clamp(t_min, t_max) - + tensor, t_min, t_max = _helpers.to_tensors_or_none(self, min, max) + result = _impl.clip(tensor, t_min, t_max) return _helpers.result_or_out(result, out) argmin = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmin)) From 5bee103f467f42bf0a2cec0f114f4fa175d1fe5d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 14 Feb 2023 20:07:41 +0300 Subject: [PATCH 04/20] MAINT: split _equal, _isclose etc --- torch_np/_detail/implementations.py | 25 ++++++++++++++++++++++++- torch_np/_wrapper.py | 22 ++++++++-------------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index bda8bbb1..2af33e6f 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -1,6 +1,8 @@ import torch -from . import _util +from . import _dtypes_impl, _util + +# ### equality, equivalence, allclose ### def tensor_equal(a1_t, a2_t, equal_nan=False): @@ -19,6 +21,27 @@ def tensor_equal(a1_t, a2_t, equal_nan=False): return bool(result.all()) +def tensor_equiv(a1_t, a2_t): + # *almost* the same as tensor_equal: _equiv tries to broadcast, _equal does not + try: + a1_t, a2_t = torch.broadcast_tensors(a1_t, a2_t) + except RuntimeError: + # failed to broadcast => not equivalent + return False + return tensor_equal(a1_t, a2_t) + + +def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): + dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) + a = a.to(dtype) + b = b.to(dtype) + result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + return result + + +# ### splits ### + + def split_helper(tensor, indices_or_sections, axis, strict=False): if isinstance(indices_or_sections, int): return split_helper_int(tensor, indices_or_sections, axis, strict) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 6ffe448b..8f0a4b2b 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -1107,17 +1107,15 @@ def isscalar(a): def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): - a, b = _helpers.to_tensors(a, b) - dtype = result_type(a, b) - torch_dtype = dtype.type.torch_dtype - a = a.to(torch_dtype) - b = b.to(torch_dtype) - return asarray(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) + a_t, b_t = _helpers.to_tensors(a, b) + result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) + return asarray(result) def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - arr_res = isclose(a, b, rtol, atol, equal_nan) - return arr_res.all() + a_t, b_t = _helpers.to_tensors(a, b) + result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) + return result.all() def array_equal(a1, a2, equal_nan=False): @@ -1128,12 +1126,8 @@ def array_equal(a1, a2, equal_nan=False): def array_equiv(a1, a2): a1_t, a2_t = _helpers.to_tensors(a1, a2) - try: - a1_t, a2_t = torch.broadcast_tensors(a1_t, a2_t) - except RuntimeError: - # failed to broadcast => not equivalent - return False - return _impl.tensor_equal(a1_t, a2_t) + result = _impl.tensor_equiv(a1_t, a2_t) + return result def common_type(): From 917cf39d8fe1239e6642d64967728d759c44253d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 14 Feb 2023 20:32:16 +0300 Subject: [PATCH 05/20] MAINT: isreal, iscomplex, real_if_close, angle --- torch_np/_detail/implementations.py | 41 +++++++++++++++++++++++++++++ torch_np/_wrapper.py | 26 ++++++------------ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 2af33e6f..a0fa47fe 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -39,6 +39,47 @@ def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): return result +# ### is arg real or complex valued ### + + +def tensor_iscomplex(x): + if torch.is_complex(x): + return torch.as_tensor(x).imag != 0 + result = torch.zeros_like(x, dtype=torch.bool) + if result.ndim == 0: + result = result.item() + return result + + +def tensor_isreal(x): + if torch.is_complex(x): + return torch.as_tensor(x).imag == 0 + result = torch.zeros_like(x, dtype=torch.bool) + if result.ndim == 0: + result = result.item() + return result + + +def tensor_real_if_close(x, tol=100): + if not torch.is_complex(x): + return x + mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps + if mask.all(): + return x.real + else: + return x + + +# ### math functions ### + + +def tensor_angle(z, deg=False): + result = torch.angle(z) + if deg: + result *= 180 / torch.pi + return result + + # ### splits ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 8f0a4b2b..1de6c2c9 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -1025,10 +1025,8 @@ def argsort(a, axis=-1, kind=None, order=None): @asarray_replacer() def angle(z, deg=False): - result = torch.angle(z) - if deg: - result *= 180 / torch.pi - return asarray(result) + result = _impl.tensor_angle(z, deg) + return result @asarray_replacer() @@ -1048,28 +1046,20 @@ def imag(a): @asarray_replacer() def real_if_close(a, tol=100): - if not torch.is_complex(a): - return a - if torch.abs(torch.imag) < tol * torch.finfo(a.dtype).eps: - return torch.real(a) - else: - return a + result = _impl.tensor_real_if_close(a, tol=tol) + return result @asarray_replacer() def iscomplex(x): - if torch.is_complex(x): - return torch.as_tensor(x).imag != 0 - result = torch.zeros_like(x, dtype=torch.bool) - return result[()] + result = _impl.tensor_iscomplex(x) + return result # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ? @asarray_replacer() def isreal(x): - if torch.is_complex(x): - return torch.as_tensor(x).imag == 0 - result = torch.zeros_like(x, dtype=torch.bool) - return result[()] + result = _impl.tensor_isreal(x) + return result @asarray_replacer() From 5031d40f4eac57c24882ed81f6e2a7c4a13abf8c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 14 Feb 2023 21:01:49 +0300 Subject: [PATCH 06/20] MAINT: split argsort --- torch_np/_detail/implementations.py | 11 +++++++++++ torch_np/_wrapper.py | 9 ++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index a0fa47fe..dade344a 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -80,6 +80,17 @@ def tensor_angle(z, deg=False): return result +# ### sorting ### + +def tensor_argsort(tensor, axis=-1, kind=None, order=None): + if order is not None: + raise NotImplementedError + stable = True if kind == "stable" else False + if axis is None: + axis = -1 + return torch.argsort(tensor, stable=stable, dim=axis, descending=False) + + # ### splits ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 1de6c2c9..7e8f13dd 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -705,7 +705,6 @@ def triu(m, k=0): return m.triu(k) -# YYY: pattern: return sequence def tril_indices(n, k=0, m=None): if m is None: m = n @@ -1012,12 +1011,8 @@ def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue): @asarray_replacer() def argsort(a, axis=-1, kind=None, order=None): - if order is not None: - raise NotImplementedError - stable = True if kind == "stable" else False - if axis is None: - axis = -1 - return torch.argsort(a, stable=stable, dim=axis, descending=False) + result = _impl.tensor_argsort(a, axis, kind, order) + return result ##### math functions From b9a8ab57389f1b888624bf389411425bc2bdfe5c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 14 Feb 2023 21:07:35 +0300 Subject: [PATCH 07/20] MAINT: split tri* family --- torch_np/_detail/implementations.py | 40 +++++++++++++++++++++++++++++ torch_np/_wrapper.py | 39 ++++++++++------------------ 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index dade344a..808f4904 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -82,6 +82,7 @@ def tensor_angle(z, deg=False): # ### sorting ### + def tensor_argsort(tensor, axis=-1, kind=None, order=None): if order is not None: raise NotImplementedError @@ -91,6 +92,45 @@ def tensor_argsort(tensor, axis=-1, kind=None, order=None): return torch.argsort(tensor, stable=stable, dim=axis, descending=False) +# ### tri*-something ### + + +def tri(N, M, k, dtype): + if M is None: + M = N + tensor = torch.ones((N, M), dtype=dtype) + tensor = torch.tril(tensor, diagonal=k) + return tensor + + +def triu_indices_from(tensor, k): + if tensor.ndim != 2: + raise ValueError("input array must be 2-d") + result = torch.triu_indices(tensor.shape[0], tensor.shape[1], offset=k) + return result + + +def tril_indices_from(tensor, k=0): + if tensor.ndim != 2: + raise ValueError("input array must be 2-d") + result = torch.tril_indices(tensor.shape[0], tensor.shape[1], offset=k) + return result + + +def tril_indices(n, k=0, m=None): + if m is None: + m = n + result = torch.tril_indices(n, m, offset=k) + return result + + +def triu_indices(n, k=0, m=None): + if m is None: + m = n + result = torch.triu_indices(n, m, offset=k) + return result + + # ### splits ### diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 7e8f13dd..9aabd863 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -601,7 +601,8 @@ def broadcast_to(array, shape, subok=False): # YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero def broadcast_arrays(*args, subok=False): _util.subok_not_ok(subok=subok) - res = torch.broadcast_tensors(*[asarray(a).get() for a in args]) + tensors = _helpers.to_tensors(*args) + res = torch.broadcast_tensors(*tensors) return tuple(asarray(_) for _ in res) @@ -706,44 +707,32 @@ def triu(m, k=0): def tril_indices(n, k=0, m=None): - if m is None: - m = n - tensor_2 = torch.tril_indices(n, m, offset=k) - return tuple(asarray(_) for _ in tensor_2) + result = _impl.tril_indices(n, k, m) + return tuple(asarray(t) for t in result) def triu_indices(n, k=0, m=None): - if m is None: - m = n - tensor_2 = torch.tril_indices(n, m, offset=k) - return tuple(asarray(_) for _ in tensor_2) + result = _impl.triu_indices(n, k, m) + return tuple(asarray(t) for t in result) -# YYY: pattern: array in, sequence of arrays out def tril_indices_from(arr, k=0): - arr = asarray(arr).get() - if arr.ndim != 2: - raise ValueError("input array must be 2-d") - tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k) - return tuple(asarray(_) for _ in tensor_2) + tensor = asarray(arr).get() + result = _impl.tril_indices_from(tensor, k) + return tuple(asarray(t) for t in result) def triu_indices_from(arr, k=0): - arr = asarray(arr).get() - if arr.ndim != 2: - raise ValueError("input array must be 2-d") - tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k) - return tuple(asarray(_) for _ in tensor_2) + tensor = asarray(arr).get() + result = _impl.triu_indices_from(tensor, k) + return tuple(asarray(t) for t in result) @_decorators.dtype_to_torch def tri(N, M=None, k=0, dtype=float, *, like=None): _util.subok_not_ok(like) - if M is None: - M = N - tensor = torch.ones((N, M), dtype=dtype) - tensor = torch.tril(tensor, diagonal=k) - return asarray(tensor) + result = _impl.tri(N, M, k, dtype) + return asarray(result) ###### reductions From 9e2e9f272f0fd6e477899ee680edfc0b4b96b1bc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 14 Feb 2023 23:38:31 +0300 Subject: [PATCH 08/20] MAINT: split bincount --- torch_np/_detail/implementations.py | 9 +++++++++ torch_np/_wrapper.py | 5 +---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 808f4904..425b11ee 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -323,3 +323,12 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"): output = [x.clone() for x in output] return output + + + +def bincount(x_tensor, /, weights_tensor=None, minlength=0): + int_dtype = _dtypes_impl.default_int_dtype + (x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe") + + result = torch.bincount(x_tensor, weights_tensor, minlength) + return result diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 9aabd863..0cc2b933 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -498,10 +498,7 @@ def bincount(x, /, weights=None, minlength=0): x = asarray([], dtype=int) x_tensor, weights_tensor = _helpers.to_tensors_or_none(x, weights) - int_dtype = _dtypes_impl.default_int_dtype - (x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe") - - result = torch.bincount(x_tensor, weights_tensor, minlength) + result = _impl.bincount(x_tensor, weights_tensor, minlength) return asarray(result) From b54866e8eeea55f8cb092d24fb940069a663a7da Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 16 Feb 2023 13:59:55 +0300 Subject: [PATCH 09/20] MAINT: split *stack/concat family --- torch_np/_detail/implementations.py | 49 ++++++++++++ torch_np/_wrapper.py | 119 +++++++++++++--------------- 2 files changed, 105 insertions(+), 63 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 425b11ee..9bc34521 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -246,6 +246,55 @@ def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): return result +def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"): + shapes = {t.shape for t in tensors} + if len(shapes) != 1: + raise ValueError("all input arrays must have the same shape") + + result_ndim = tensors[0].ndim + 1 + axis = _util.normalize_axis_index(axis, result_ndim) + + sl = (slice(None),) * axis + (None,) + expanded_tensors = [tensor[sl] for tensor in tensors] + result = concatenate(expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting) + + return result + + +def column_stack(tensors_, *, dtype=None, casting="same_kind"): + tensors = [] + for t in tensors_: + if t.ndim < 2: + t = _util._coerce_to_tensor(t, copy=False, ndmin=2).mT + tensors.append(t) + + result = concatenate(tensors, 1, dtype=dtype, casting=casting) + return result + + +def dstack(tensors, *, dtype=None, casting="same_kind"): + tensors = torch.atleast_3d(tensors) + result = concatenate(tensors, 2, dtype=dtype, casting=casting) + return result + + +def hstack(tensors, *, dtype=None, casting="same_kind"): + tensors = torch.atleast_1d(tensors) + + # As a special case, dimension 0 of 1-dimensional arrays is "horizontal"s + if tensors and tensors[0].ndim == 1: + result = concatenate(tensors, 0, dtype=dtype, casting=casting) + else: + result = concatenate(tensors, 1, dtype=dtype, casting=casting) + return result + + +def vstack(tensors, *, dtype=None, casting="same_kind"): + tensors = torch.atleast_2d(tensors) + result = concatenate(tensors, 0, dtype=dtype, casting=casting) + return result + + # #### cov & corrcoef diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 0cc2b933..a99a7c29 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -74,7 +74,8 @@ def copy(a, order="K", subok=False): def atleast_1d(*arys): - res = torch.atleast_1d([asarray(a).get() for a in arys]) + tensors = _helpers.to_tensors(*arys) + res = torch.atleast_1d(tensors) if len(res) == 1: return asarray(res[0]) else: @@ -82,7 +83,8 @@ def atleast_1d(*arys): def atleast_2d(*arys): - res = torch.atleast_2d([asarray(a).get() for a in arys]) + tensors = _helpers.to_tensors(*arys) + res = torch.atleast_2d(tensors) if len(res) == 1: return asarray(res[0]) else: @@ -90,73 +92,85 @@ def atleast_2d(*arys): def atleast_3d(*arys): - res = torch.atleast_3d([asarray(a).get() for a in arys]) + tensors = _helpers.to_tensors(*arys) + res = torch.atleast_3d(tensors) if len(res) == 1: return asarray(res[0]) else: return list(asarray(_) for _ in res) +def _concat_check(tup, dtype, out): + """Check inputs in concatenate et al.""" + if tup == (): + # XXX: RuntimeError in torch, ValueError in numpy + raise ValueError("need at least one array to concatenate") + + if out is not None: + if not isinstance(out, ndarray): + raise ValueError("'out' must be an array") + + if dtype is not None: + # mimic numpy + raise TypeError( + "concatenate() only takes `out` or `dtype` as an " + "argument, but both were provided." + ) + + +@_decorators.dtype_to_torch +def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): + _concat_check(ar_tuple, dtype, out=out) + tensors = _helpers.to_tensors(*ar_tuple) + result = _impl.concatenate(tensors, axis, out, dtype, casting) + return _helpers.result_or_out(result, out) + + +@_decorators.dtype_to_torch def vstack(tup, *, dtype=None, casting="same_kind"): - arrs = atleast_2d(*tup) - if not isinstance(arrs, list): - arrs = [arrs] - return concatenate(arrs, 0, dtype=dtype, casting=casting) + tensors = _helpers.to_tensors(*tup) + _concat_check(tensors, dtype, out=None) + result = _impl.vstack(tensors, dtype=dtype, casting=casting) + return asarray(result) row_stack = vstack +@_decorators.dtype_to_torch def hstack(tup, *, dtype=None, casting="same_kind"): - arrs = atleast_1d(*tup) - if not isinstance(arrs, list): - arrs = [arrs] - # As a special case, dimension 0 of 1-dimensional arrays is "horizontal" - if arrs and arrs[0].ndim == 1: - return concatenate(arrs, 0, dtype=dtype, casting=casting) - else: - return concatenate(arrs, 1, dtype=dtype, casting=casting) + tensors = _helpers.to_tensors(*tup) + _concat_check(tensors, dtype, out=None) + result = _impl.hstack(tensors, dtype=dtype, casting=casting) + return asarray(result) +@_decorators.dtype_to_torch def dstack(tup, *, dtype=None, casting="same_kind"): # XXX: in numpy 1.24 dstack does not have dtype and casting keywords # but {h,v}stack do. Hence add them here for consistency. - arrs = atleast_3d(*tup) - if not isinstance(arrs, list): - arrs = [arrs] - return concatenate(arrs, 2, dtype=dtype, casting=casting) + tensors = _helpers.to_tensors(*tup) + result = _impl.dstack(tensors, dtype=dtype, casting=casting) + return asarray(result) +@_decorators.dtype_to_torch def column_stack(tup, *, dtype=None, casting="same_kind"): # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords # but row_stack does. (because row_stack is an alias for vstack, really). # Hence add these keywords here for consistency. - arrays = [] - for v in tup: - arr = asarray(v) - if arr.ndim < 2: - arr = array(arr, copy=False, ndmin=2).T - arrays.append(arr) - return concatenate(arrays, 1, dtype=dtype, casting=casting) + tensors = _helpers.to_tensors(*tup) + _concat_check(tensors, dtype, out=None) + result = _impl.column_stack(tensors, dtype=dtype, casting=casting) + return asarray(result) +@_decorators.dtype_to_torch def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"): - arrays = [asarray(arr) for arr in arrays] - if not arrays: - raise ValueError("need at least one array to stack") - - shapes = {arr.shape for arr in arrays} - if len(shapes) != 1: - raise ValueError("all input arrays must have the same shape") - - result_ndim = arrays[0].ndim + 1 - axis = _util.normalize_axis_index(axis, result_ndim) - - sl = (slice(None),) * axis + (newaxis,) - expanded_arrays = [arr[sl] for arr in arrays] - return concatenate( - expanded_arrays, axis=axis, out=out, dtype=dtype, casting=casting - ) + tensors = _helpers.to_tensors(*arrays) + _concat_check(tensors, dtype, out=out) + result = _impl.stack(tensors, axis=axis, out=out, dtype=dtype, casting=casting) + return _helpers.result_or_out(result, out) def array_split(ary, indices_or_sections, axis=0): @@ -471,27 +485,6 @@ def cov( return asarray(result) -@_decorators.dtype_to_torch -def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"): - if ar_tuple == (): - # XXX: RuntimeError in torch, ValueError in numpy - raise ValueError("need at least one array to concatenate") - - if out is not None: - if not isinstance(out, ndarray): - raise ValueError("'out' must be an array") - - if dtype is not None: - # mimic numpy - raise TypeError( - "concatenate() only takes `out` or `dtype` as an " - "argument, but both were provided." - ) - tensors = _helpers.to_tensors(*ar_tuple) - result = _impl.concatenate(tensors, axis, out, dtype, casting) - return _helpers.result_or_out(result, out) - - def bincount(x, /, weights=None, minlength=0): if not isinstance(x, ndarray) and x == []: # edge case allowed by numpy From 8f10e1b46e98579848e32b16c0eb2d0f6b031ddc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 11:42:45 +0300 Subject: [PATCH 10/20] MAINT: split arange/geomspace + full/empty etc --- torch_np/_detail/implementations.py | 100 ++++++++++++++++++++++++++-- torch_np/_wrapper.py | 77 +++++---------------- 2 files changed, 114 insertions(+), 63 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 9bc34521..c5d4b8cf 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -138,7 +138,8 @@ def split_helper(tensor, indices_or_sections, axis, strict=False): if isinstance(indices_or_sections, int): return split_helper_int(tensor, indices_or_sections, axis, strict) elif isinstance(indices_or_sections, (list, tuple)): - return split_helper_list(tensor, list(indices_or_sections), axis, strict) + # NB: drop split=..., it only applies to split_helper_int + return split_helper_list(tensor, list(indices_or_sections), axis) else: raise TypeError("split_helper: ", type(indices_or_sections)) @@ -170,7 +171,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False): return result -def split_helper_list(tensor, indices_or_sections, axis, strict=False): +def split_helper_list(tensor, indices_or_sections, axis): if not isinstance(indices_or_sections, list): raise NotImplementedError("split: indices_or_sections: list") # numpy expectes indices, while torch expects lengths of sections @@ -256,7 +257,9 @@ def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"): sl = (slice(None),) * axis + (None,) expanded_tensors = [tensor[sl] for tensor in tensors] - result = concatenate(expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting) + result = concatenate( + expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting + ) return result @@ -374,10 +377,99 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"): return output - def bincount(x_tensor, /, weights_tensor=None, minlength=0): int_dtype = _dtypes_impl.default_int_dtype (x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe") result = torch.bincount(x_tensor, weights_tensor, minlength) return result + + +# ### linspace, geomspace, logspace and arange ### + + +def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): + if axis != 0 or not endpoint: + raise NotImplementedError + tstart, tstop = torch.as_tensor([start, stop]) + base = torch.pow(tstop / tstart, 1.0 / (num - 1)) + result = torch.logspace( + torch.log(tstart) / torch.log(base), + torch.log(tstop) / torch.log(base), + num, + base=base, + ) + return result + + +def arange(start=None, stop=None, step=1, dtype=None): + if step == 0: + raise ZeroDivisionError + if stop is None and start is None: + raise TypeError + if stop is None: + # XXX: this breaks if start is passed as a kwarg: + # arange(start=4) should raise (no stop) but doesn't + start, stop = 0, start + if start is None: + start = 0 + + if dtype is None: + dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)] + dtype = _dtypes_impl.default_int_dtype + dt_list.append(dtype) + dtype = _dtypes_impl.result_type_impl(dt_list) + + try: + return torch.arange(start, stop, step, dtype=dtype) + except RuntimeError: + raise ValueError("Maximum allowed size exceeded") + + +# ### empty/full et al ### + + +def eye(N, M=None, k=0, dtype=float): + if M is None: + M = N + z = torch.zeros(N, M, dtype=dtype) + z.diagonal(k).fill_(1) + return z + + +def zeros_like(a, dtype=None, shape=None): + result = torch.zeros_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def ones_like(a, dtype=None, shape=None): + result = torch.ones_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def full_like(a, fill_value, dtype=None, shape=None): + # XXX: fill_value broadcasts + result = torch.full_like(a, fill_value, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def empty_like(prototype, dtype=None, shape=None): + result = torch.empty_like(prototype, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def full(shape, fill_value, dtype=None): + if dtype is None: + dtype = fill_value.dtype + if not isinstance(shape, (tuple, list)): + shape = (shape,) + result = torch.full(shape, fill_value, dtype=dtype) + return result diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index a99a7c29..246e93ae 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -151,7 +151,7 @@ def dstack(tup, *, dtype=None, casting="same_kind"): # but {h,v}stack do. Hence add them here for consistency. tensors = _helpers.to_tensors(*tup) result = _impl.dstack(tensors, dtype=dtype, casting=casting) - return asarray(result) + return asarray(result) @_decorators.dtype_to_torch @@ -257,49 +257,27 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis return asarray(torch.linspace(start, stop, num, dtype=dtype)) +@_decorators.dtype_to_torch def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): if axis != 0 or not endpoint: raise NotImplementedError - tstart, tstop = torch.as_tensor([start, stop]) - base = torch.pow(tstop / tstart, 1.0 / (num - 1)) - result = torch.logspace( - torch.log(tstart) / torch.log(base), - torch.log(tstop) / torch.log(base), - num, - base=base, - ) + result = _impl.geomspace(start, stop, num, endpoint, dtype, axis) return asarray(result) +@_decorators.dtype_to_torch def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): if axis != 0 or not endpoint: raise NotImplementedError return asarray(torch.logspace(start, stop, num, base=base, dtype=dtype)) +@_decorators.dtype_to_torch def arange(start=None, stop=None, step=1, dtype=None, *, like=None): _util.subok_not_ok(like) - if step == 0: - raise ZeroDivisionError - if stop is None and start is None: - raise TypeError - if stop is None: - # XXX: this breaks if start is passed as a kwarg: - # arange(start=4) should raise (no stop) but doesn't - start, stop = 0, start - if start is None: - start = 0 - - if dtype is None: - dtype = _dtypes.default_int_type() - dtype = result_type(start, stop, step, dtype) - torch_dtype = _dtypes.torch_dtype_from(dtype) start, stop, step = _helpers.ndarrays_to_tensors(start, stop, step) - - try: - return asarray(torch.arange(start, stop, step, dtype=torch_dtype)) - except RuntimeError: - raise ValueError("Maximum allowed size exceeded") + result = _impl.arange(start, stop, step, dtype=dtype) + return asarray(result) @_decorators.dtype_to_torch @@ -316,14 +294,12 @@ def empty(shape, dtype=float, order="C", *, like=None): # NB: *_like function deliberately deviate from numpy: it has subok=True # as the default; we set subok=False and raise on anything else. @asarray_replacer() +@_decorators.dtype_to_torch def empty_like(prototype, dtype=None, order="K", subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != "K": raise NotImplementedError - torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) - result = torch.empty_like(prototype, dtype=torch_dtype) - if shape is not None: - result = result.reshape(shape) + result = _impl.empty_like(prototype, dtype=dtype, shape=shape) return result @@ -332,28 +308,18 @@ def full(shape, fill_value, dtype=None, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - fill_value = asarray(fill_value).get() - if dtype is None: - dtype = fill_value.dtype - - if not isinstance(shape, (tuple, list)): - shape = (shape,) - - result = torch.full(shape, fill_value, dtype=dtype) - + result = _impl.full(shape, fill_value, dtype=dtype) return asarray(result) @asarray_replacer() +@_decorators.dtype_to_torch def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != "K": raise NotImplementedError - torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) - result = torch.full_like(a, fill_value, dtype=torch_dtype) - if shape is not None: - result = result.reshape(shape) + result = _impl.full_like(a, fill_value, dtype=dtype, shape=shape) return result @@ -369,14 +335,12 @@ def ones(shape, dtype=None, order="C", *, like=None): @asarray_replacer() +@_decorators.dtype_to_torch def ones_like(a, dtype=None, order="K", subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != "K": raise NotImplementedError - torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) - result = torch.ones_like(a, dtype=torch_dtype) - if shape is not None: - result = result.reshape(shape) + result = _impl.ones_like(a, dtype=dtype, shape=shape) return result @@ -392,14 +356,12 @@ def zeros(shape, dtype=None, order="C", *, like=None): @asarray_replacer() +@_decorators.dtype_to_torch def zeros_like(a, dtype=None, order="K", subok=False, shape=None): _util.subok_not_ok(subok=subok) if order != "K": raise NotImplementedError - torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype) - result = torch.zeros_like(a, dtype=torch_dtype) - if shape is not None: - result = result.reshape(shape) + result = _impl.zeros_like(a, dtype=dtype, shape=shape) return result @@ -408,11 +370,8 @@ def eye(N, M=None, k=0, dtype=float, order="C", *, like=None): _util.subok_not_ok(like) if order != "C": raise NotImplementedError - if M is None: - M = N - z = torch.zeros(N, M, dtype=dtype) - z.diagonal(k).fill_(1) - return asarray(z) + result = _impl.eye(N, M, k, dtype) + return asarray(result) def identity(n, dtype=None, *, like=None): From 50e0dd98e7357ea54a00b9fc922a284942b92151 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 14:30:35 +0300 Subject: [PATCH 11/20] MAINT: deduplicate y != None in cov, corrcoef --- torch_np/_detail/implementations.py | 6 +-- torch_np/_wrapper.py | 59 ++++++++++++++++------------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index c5d4b8cf..fe279d1c 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -301,11 +301,7 @@ def vstack(tensors, *, dtype=None, casting="same_kind"): # #### cov & corrcoef -def corrcoef(xy_tensor, rowvar=True, *, dtype=None): - if rowvar is False: - # xy_tensor is at least 2D, so using .T is safe - xy_tensor = x_tensor.T - +def corrcoef(xy_tensor, *, dtype=None): is_half = dtype == torch.float16 if is_half: # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 246e93ae..d75abe76 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -388,26 +388,40 @@ def diag(v, k=0): ###### misc/unordered +def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True): + """Prepate inputs for cov and corrcoef.""" + + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 + if y_tensor is not None: + # make sure x and y are at least 2D + ndim_extra = 2 - x_tensor.ndim + if ndim_extra > 0: + x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape) + if not rowvar and x_tensor.shape[0] != 1: + x_tensor = x_tensor.mT + x_tensor = x_tensor.clone() + + ndim_extra = 2 - y_tensor.ndim + if ndim_extra > 0: + y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape) + if not rowvar and y_tensor.shape[0] != 1: + y_tensor = y_tensor.mT + y_tensor = y_tensor.clone() + + x_tensor = _impl.concatenate((x_tensor, y_tensor), axis=0) + + return x_tensor + + @_decorators.dtype_to_torch def corrcoef(x, y=None, rowvar=True, bias=NoValue, ddof=NoValue, *, dtype=None): if bias is not None or ddof is not None: # deprecated in NumPy raise NotImplementedError - # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 - if y is not None: - x = array(x, ndmin=2) - if not rowvar and x.shape[0] != 1: - x = x.T - - y = array(y, ndmin=2) - if not rowvar and y.shape[0] != 1: - y = y.T - - x = concatenate((x, y), axis=0) - - x_tensor = asarray(x).get() - result = _impl.corrcoef(x_tensor, rowvar, dtype=dtype) + x_tensor, y_tensor = _helpers.to_tensors_or_none(x, y) + tensor = _xy_helper_corrcoef(x_tensor, y_tensor, rowvar) + result = _impl.corrcoef(tensor, dtype=dtype) return asarray(result) @@ -423,21 +437,12 @@ def cov( *, dtype=None, ): - # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 - if y is not None: - m = array(m, ndmin=2) - if not rowvar and m.shape[0] != 1: - m = m.T - - y = array(y, ndmin=2) - if not rowvar and y.shape[0] != 1: - y = y.T - m = concatenate((m, y), axis=0) - - m_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none( - m, fweights, aweights + m_tensor, y_tensor, fweights_tensor, aweights_tensor = _helpers.to_tensors_or_none( + m, y, fweights, aweights ) + m_tensor = _xy_helper_corrcoef(m_tensor, y_tensor, rowvar) + result = _impl.cov( m_tensor, bias, ddof, fweights_tensor, aweights_tensor, dtype=dtype ) From 59c65969488eb40102fa0f9c07716ef89b9eb16b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 14:42:36 +0300 Subject: [PATCH 12/20] MAINT: remove a redundant one-liner --- torch_np/_dtypes.py | 4 ---- torch_np/_getlimits.py | 4 ++-- torch_np/_ndarray.py | 4 ++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py index c61db2c0..ce1cf13f 100644 --- a/torch_np/_dtypes.py +++ b/torch_np/_dtypes.py @@ -254,10 +254,6 @@ def dtype(arg): return DType(arg) -def torch_dtype_from(dtype_arg): - return dtype(dtype_arg).torch_dtype - - class DType: def __init__(self, arg): # a pytorch object? diff --git a/torch_np/_getlimits.py b/torch_np/_getlimits.py index ba36d2c1..229c8963 100644 --- a/torch_np/_getlimits.py +++ b/torch_np/_getlimits.py @@ -4,12 +4,12 @@ def finfo(dtyp): - torch_dtype = _dtypes.torch_dtype_from(dtyp) + torch_dtype = _dtypes.dtype(dtyp).torch_dtype return torch.finfo(torch_dtype) def iinfo(dtyp): - torch_dtype = _dtypes.torch_dtype_from(dtyp) + torch_dtype = _dtypes.dtype(dtyp).torch_dtype return torch.iinfo(torch_dtype) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 7b95f15a..f310af7d 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -149,7 +149,7 @@ def round(self, decimals=0, out=None): # ctors def astype(self, dtype): newt = ndarray() - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = _dtypes.dtype(dtype).torch_dtype newt._tensor = self._tensor.to(torch_dtype) return newt @@ -439,7 +439,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N # is a specific dtype requrested? torch_dtype = None if dtype is not None: - torch_dtype = _dtypes.torch_dtype_from(dtype) + torch_dtype = _dtypes.dtype(dtype).torch_dtype base = None tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin) From 69135b9bf12ee3ef1d4fa646d56c322bc981aea8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 18:17:36 +0300 Subject: [PATCH 13/20] BUG: isreal for real args is all-true Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> --- torch_np/_detail/implementations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index fe279d1c..33847f0a 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -54,7 +54,7 @@ def tensor_iscomplex(x): def tensor_isreal(x): if torch.is_complex(x): return torch.as_tensor(x).imag == 0 - result = torch.zeros_like(x, dtype=torch.bool) + result = torch.ones_like(x, dtype=torch.bool) if result.ndim == 0: result = result.item() return result From 85d1ba66fa5d874ff87d67559fe86693444b31a7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 18:19:25 +0300 Subject: [PATCH 14/20] TST: test isreal(real array) --- .../tests/numpy_tests/lib/test_type_check.py | 43 +++---------------- 1 file changed, 5 insertions(+), 38 deletions(-) diff --git a/torch_np/tests/numpy_tests/lib/test_type_check.py b/torch_np/tests/numpy_tests/lib/test_type_check.py index 8c9d9df4..4567d4b4 100644 --- a/torch_np/tests/numpy_tests/lib/test_type_check.py +++ b/torch_np/tests/numpy_tests/lib/test_type_check.py @@ -157,7 +157,6 @@ def test_cmplx(self): assert_(not isinstance(out, np.ndarray)) -@pytest.mark.xfail(reason="not implemented") class TestIscomplex: def test_fail(self): @@ -171,7 +170,6 @@ def test_pass(self): assert_array_equal(res, [1, 0, 0]) -@pytest.mark.xfail(reason="not implemented") class TestIsreal: def test_pass(self): @@ -184,6 +182,11 @@ def test_fail(self): res = isreal(z) assert_array_equal(res, [0, 1, 1]) + def test_isreal_real(self): + z = np.array([-1, 0, 1]) + res = isreal(z) + assert res.all() + @pytest.mark.xfail(reason="not implemented") class TestIscomplexobj: @@ -202,42 +205,6 @@ def test_list(self): assert_(iscomplexobj([3, 1+0j, True])) assert_(not iscomplexobj([3, 1, True])) - def test_duck(self): - class DummyComplexArray: - @property - def dtype(self): - return np.dtype(complex) - dummy = DummyComplexArray() - assert_(iscomplexobj(dummy)) - - def test_pandas_duck(self): - # This tests a custom np.dtype duck-typed class, such as used by pandas - # (pandas.core.dtypes) - class PdComplex(np.complex128): - pass - class PdDtype: - name = 'category' - names = None - type = PdComplex - kind = 'c' - str = ' Date: Fri, 17 Feb 2023 18:38:16 +0300 Subject: [PATCH 15/20] MAINT: address review comments --- torch_np/_detail/implementations.py | 25 +++++++++++++------------ torch_np/_wrapper.py | 15 ++++++++------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 33847f0a..5f7fedb8 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -31,7 +31,7 @@ def tensor_equiv(a1_t, a2_t): return tensor_equal(a1_t, a2_t) -def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): +def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) a = a.to(dtype) b = b.to(dtype) @@ -42,7 +42,7 @@ def tensor_isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): # ### is arg real or complex valued ### -def tensor_iscomplex(x): +def iscomplex(x): if torch.is_complex(x): return torch.as_tensor(x).imag != 0 result = torch.zeros_like(x, dtype=torch.bool) @@ -51,7 +51,7 @@ def tensor_iscomplex(x): return result -def tensor_isreal(x): +def isreal(x): if torch.is_complex(x): return torch.as_tensor(x).imag == 0 result = torch.ones_like(x, dtype=torch.bool) @@ -60,7 +60,8 @@ def tensor_isreal(x): return result -def tensor_real_if_close(x, tol=100): +def real_if_close(x, tol=100): + # XXX: copies vs views; numpy seems to return a copy? if not torch.is_complex(x): return x mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps @@ -73,20 +74,20 @@ def tensor_real_if_close(x, tol=100): # ### math functions ### -def tensor_angle(z, deg=False): +def angle(z, deg=False): result = torch.angle(z) if deg: - result *= 180 / torch.pi + result = result * 180 / torch.pi return result # ### sorting ### -def tensor_argsort(tensor, axis=-1, kind=None, order=None): +def argsort(tensor, axis=-1, kind=None, order=None): if order is not None: raise NotImplementedError - stable = True if kind == "stable" else False + stable = kind == "stable" if axis is None: axis = -1 return torch.argsort(tensor, stable=stable, dim=axis, descending=False) @@ -387,11 +388,11 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0): def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): if axis != 0 or not endpoint: raise NotImplementedError - tstart, tstop = torch.as_tensor([start, stop]) - base = torch.pow(tstop / tstart, 1.0 / (num - 1)) + base = torch.pow(stop / start, 1.0 / (num - 1)) + logbase = torch.log(base) result = torch.logspace( - torch.log(tstart) / torch.log(base), - torch.log(tstop) / torch.log(base), + torch.log(start) / logbase, + torch.log(stop) / logbase, num, base=base, ) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index d75abe76..775d2263 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -261,6 +261,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): if axis != 0 or not endpoint: raise NotImplementedError + start, stop = _helpers.to_tensors(start, stop) result = _impl.geomspace(start, stop, num, endpoint, dtype, axis) return asarray(result) @@ -954,7 +955,7 @@ def diff(a, n=1, axis=-1, prepend=NoValue, append=NoValue): @asarray_replacer() def argsort(a, axis=-1, kind=None, order=None): - result = _impl.tensor_argsort(a, axis, kind, order) + result = _impl.argsort(a, axis, kind, order) return result @@ -963,7 +964,7 @@ def argsort(a, axis=-1, kind=None, order=None): @asarray_replacer() def angle(z, deg=False): - result = _impl.tensor_angle(z, deg) + result = _impl.angle(z, deg) return result @@ -984,19 +985,19 @@ def imag(a): @asarray_replacer() def real_if_close(a, tol=100): - result = _impl.tensor_real_if_close(a, tol=tol) + result = _impl.real_if_close(a, tol=tol) return result @asarray_replacer() def iscomplex(x): - result = _impl.tensor_iscomplex(x) + result = _impl.iscomplex(x) return result # XXX: missing .item on a zero-dim value; a case for array_or_scalar(value) ? @asarray_replacer() def isreal(x): - result = _impl.tensor_isreal(x) + result = _impl.isreal(x) return result @@ -1036,13 +1037,13 @@ def isscalar(a): def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): a_t, b_t = _helpers.to_tensors(a, b) - result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) + result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) return asarray(result) def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): a_t, b_t = _helpers.to_tensors(a, b) - result = _impl.tensor_isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) + result = _impl.isclose(a_t, b_t, rtol, atol, equal_nan=equal_nan) return result.all() From 1b3708a5e3eb2e1a6203b11b9c59d32a727fe6d3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 19:44:22 +0300 Subject: [PATCH 16/20] MAINT: add a utility for to wrap tensor.to(dtype != tensor.dtype) --- torch_np/_detail/_reductions.py | 17 ++++++----------- torch_np/_detail/_util.py | 18 ++++++++++-------- torch_np/_detail/implementations.py | 11 ++++------- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/torch_np/_detail/_reductions.py b/torch_np/_detail/_reductions.py index 88c5a910..d7688787 100644 --- a/torch_np/_detail/_reductions.py +++ b/torch_np/_detail/_reductions.py @@ -146,9 +146,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): raise NotImplementedError dtype = _atleast_float(dtype, tensor.dtype) - - if dtype is not None: - tensor = tensor.to(dtype) + tensor = _util.cast_if_needed(tensor, dtype) result = tensor.std(dim=axis, correction=ddof) return result @@ -159,9 +157,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue): raise NotImplementedError dtype = _atleast_float(dtype, tensor.dtype) - - if dtype is not None: - tensor = tensor.to(dtype) + tensor = _util.cast_if_needed(tensor, dtype) result = tensor.var(dim=axis, correction=ddof) return result @@ -204,10 +200,9 @@ def average(a_tensor, axis, w_tensor): a_tensor = a_tensor.to(result_dtype) result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype]) - if a_tensor.dtype != result_dtype: - a_tensor = a_tensor.to(result_dtype) - if w_tensor.dtype != result_dtype: - w_tensor = w_tensor.to(result_dtype) + + a_tensor = _util.cast_if_needed(a_tensor, result_dtype) + w_tensor = _util.cast_if_needed(w_tensor, result_dtype) # axis if axis is None: @@ -258,7 +253,7 @@ def quantile(a_tensor, q_tensor, axis, method): axis = _util.normalize_axis_tuple(axis, a_tensor.ndim) axis = _util.allow_only_single_axis(axis) - q_tensor = q_tensor.to(a_tensor.dtype) + q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype) (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis) diff --git a/torch_np/_detail/_util.py b/torch_np/_detail/_util.py index 50ed3a89..ef5f9a7e 100644 --- a/torch_np/_detail/_util.py +++ b/torch_np/_detail/_util.py @@ -34,6 +34,13 @@ class UFuncTypeError(TypeError, RuntimeError): pass +def cast_if_needed(tensor, dtype): + # NB: no casting if dtype=None + if tensor.dtype != dtype: + tensor = tensor.to(dtype) + return tensor + + # a replica of the version in ./numpy/numpy/core/src/multiarray/common.h def normalize_axis_index(ax, ndim, argname=None): if not (-ndim <= ax < ndim): @@ -156,10 +163,7 @@ def cast_dont_broadcast(tensors, target_dtype, casting): f"Cannot cast array data from {tensor.dtype} to" f" {target_dtype} according to the rule '{casting}'" ) - - # cast if needed - if tensor.dtype != target_dtype: - tensor = tensor.to(target_dtype) + tensor = cast_if_needed(tensor, target_dtype) cast_tensors.append(tensor) return tuple(cast_tensors) @@ -200,8 +204,7 @@ def cast_and_broadcast(tensors, out_param, casting): ) # cast arr if needed - if tensor.dtype != target_dtype: - tensor = tensor.to(target_dtype) + tensor = cast_if_needed(tensor, target_dtype) # `out` broadcasts `tensor` if tensor.shape != target_shape: @@ -285,8 +288,7 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): tensor = torch.as_tensor(obj, dtype=torch_dtype) # type cast if requested - if dtype is not None: - tensor = tensor.to(dtype) + tensor = cast_if_needed(tensor, dtype) # adjust ndim if needed ndim_extra = ndmin - tensor.ndim diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 5f7fedb8..7e176f76 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -33,8 +33,8 @@ def tensor_equiv(a1_t, a2_t): def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) - a = a.to(dtype) - b = b.to(dtype) + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) return result @@ -308,9 +308,7 @@ def corrcoef(xy_tensor, *, dtype=None): # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" dtype = torch.float32 - if dtype is not None: - xy_tensor = xy_tensor.to(dtype) - + xy_tensor = _util.cast_if_needed(xy_tensor, dtype) result = torch.corrcoef(xy_tensor) if is_half: @@ -336,8 +334,7 @@ def cov( # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" dtype = torch.float32 - if dtype is not None: - m_tensor = m_tensor.to(dtype) + m_tensor = _util.cast_if_needed(m_tensor, dtype) result = torch.cov( m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor From 1dbacff3151ce003dd0de8dc76176bd1b2e6b581 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 17 Feb 2023 21:35:01 +0300 Subject: [PATCH 17/20] MAINT: reimplement {v,d,h, column_}stack through their pytorch equivalents --- torch_np/_detail/implementations.py | 67 +++++++++++++---------------- 1 file changed, 30 insertions(+), 37 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 7e176f76..805d2d80 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -229,73 +229,66 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None): # #### concatenate and relatives -def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): - # np.concatenate ravels if axis=None - tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) +def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): + """Figure out dtypes, cast if necessary.""" if out is not None or dtype is not None: # figure out the type of the inputs and outputs out_dtype = out.dtype.torch_dtype if dtype is None else dtype + else: + out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors]) + + # cast input arrays if necessary; do not broadcast them agains `out` + tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting) - # cast input arrays if necessary; do not broadcast them agains `out` - tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting) + return tensors + + +def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"): + # np.concatenate ravels if axis=None + tensors, axis = _util.axis_none_ravel(*tensors, axis=axis) + tensors = _concat_cast_helper(tensors, out, dtype, casting) try: result = torch.cat(tensors, axis) - except (IndexError, RuntimeError): - raise _util.AxisError + except (IndexError, RuntimeError) as e: + raise _util.AxisError(*e.args) return result def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"): - shapes = {t.shape for t in tensors} - if len(shapes) != 1: - raise ValueError("all input arrays must have the same shape") - + tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting) result_ndim = tensors[0].ndim + 1 axis = _util.normalize_axis_index(axis, result_ndim) - - sl = (slice(None),) * axis + (None,) - expanded_tensors = [tensor[sl] for tensor in tensors] - result = concatenate( - expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting - ) - + try: + result = torch.stack(tensors, axis=axis) + except RuntimeError as e: + raise ValueError(*e.args) return result -def column_stack(tensors_, *, dtype=None, casting="same_kind"): - tensors = [] - for t in tensors_: - if t.ndim < 2: - t = _util._coerce_to_tensor(t, copy=False, ndmin=2).mT - tensors.append(t) - - result = concatenate(tensors, 1, dtype=dtype, casting=casting) +def column_stack(tensors, *, dtype=None, casting="same_kind"): + tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting) + result = torch.column_stack(tensors) return result def dstack(tensors, *, dtype=None, casting="same_kind"): - tensors = torch.atleast_3d(tensors) - result = concatenate(tensors, 2, dtype=dtype, casting=casting) + tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting) + result = torch.dstack(tensors) return result def hstack(tensors, *, dtype=None, casting="same_kind"): - tensors = torch.atleast_1d(tensors) - - # As a special case, dimension 0 of 1-dimensional arrays is "horizontal"s - if tensors and tensors[0].ndim == 1: - result = concatenate(tensors, 0, dtype=dtype, casting=casting) - else: - result = concatenate(tensors, 1, dtype=dtype, casting=casting) + tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting) + result = torch.hstack(tensors) return result def vstack(tensors, *, dtype=None, casting="same_kind"): - tensors = torch.atleast_2d(tensors) - result = concatenate(tensors, 0, dtype=dtype, casting=casting) + tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting) + result = torch.vstack(tensors) return result From 7724e7bf055befdd8780b654a86e8fdf057458e1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 18 Feb 2023 12:51:35 +0300 Subject: [PATCH 18/20] BUG: fix tolerance of real_if_close --- torch_np/_detail/implementations.py | 8 +++++++- torch_np/tests/numpy_tests/lib/test_type_check.py | 11 ++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 805d2d80..4f915f39 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -64,7 +64,13 @@ def real_if_close(x, tol=100): # XXX: copies vs views; numpy seems to return a copy? if not torch.is_complex(x): return x - mask = torch.abs(x.imag) < tol * torch.finfo(x.dtype).eps + if tol > 1: + # Undocumented in numpy: if tol < 1, it's an absolute tolerance! + # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577 + tol = tol * torch.finfo(x.dtype).eps + + mask = torch.abs(x.imag) < tol if mask.all(): return x.real else: diff --git a/torch_np/tests/numpy_tests/lib/test_type_check.py b/torch_np/tests/numpy_tests/lib/test_type_check.py index 4567d4b4..341ee1bc 100644 --- a/torch_np/tests/numpy_tests/lib/test_type_check.py +++ b/torch_np/tests/numpy_tests/lib/test_type_check.py @@ -82,7 +82,7 @@ def test_default_3(self): assert_equal(mintypecode('idD'), 'D') -@pytest.mark.xfail(reason="not implemented") +@pytest.mark.xfail(reason="TODO: decide on if [1] is a scalar or not") class TestIsscalar: def test_basic(self): @@ -188,7 +188,6 @@ def test_isreal_real(self): assert res.all() -@pytest.mark.xfail(reason="not implemented") class TestIscomplexobj: def test_basic(self): @@ -206,7 +205,7 @@ def test_list(self): assert_(not iscomplexobj([3, 1, True])) -@pytest.mark.xfail(reason="not implemented") + class TestIsrealobj: def test_basic(self): z = np.array([-1, 0, 1]) @@ -215,7 +214,6 @@ def test_basic(self): assert_(not isrealobj(z)) -@pytest.mark.xfail(reason="not implemented") class TestIsnan: def test_goodvalues(self): @@ -246,7 +244,6 @@ def test_complex1(self): assert_all(np.isnan(np.array(0+0j)/0.) == 1) -@pytest.mark.xfail(reason="not implemented") class TestIsfinite: # Fixme, wrong place, isfinite now ufunc @@ -278,7 +275,6 @@ def test_complex1(self): assert_all(np.isfinite(np.array(1+1j)/0.) == 0) -@pytest.mark.xfail(reason="not implemented") class TestIsinf: # Fixme, wrong place, isinf now ufunc @@ -308,7 +304,6 @@ def test_ind(self): assert_all(np.isinf(np.array((0.,))/0.) == 0) -@pytest.mark.xfail(reason="not implemented") class TestIsposinf: def test_generic(self): @@ -319,7 +314,6 @@ def test_generic(self): assert_(vals[2] == 1) -@pytest.mark.xfail(reason="not implemented") class TestIsneginf: def test_generic(self): @@ -436,7 +430,6 @@ def test_do_not_rewrite_previous_keyword(self): assert_equal(type(vals), np.ndarray) -@pytest.mark.xfail(reason="not implemented") class TestRealIfClose: def test_basic(self): From 6de03dd52ef16820742771e2c5231d0ae9fad085 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 19 Feb 2023 13:04:19 +0300 Subject: [PATCH 19/20] MAINT: move *split logic to _impl --- torch_np/_detail/implementations.py | 22 ++++++++++++++++++++++ torch_np/_wrapper.py | 27 +++------------------------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index 4f915f39..d688f7cb 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -155,6 +155,8 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False): if not isinstance(indices_or_sections, int): raise NotImplementedError("split: indices_or_sections") + axis = _util.normalize_axis_index(axis, tensor.ndim) + # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n l, n = tensor.shape[axis], indices_or_sections @@ -195,6 +197,26 @@ def split_helper_list(tensor, indices_or_sections, axis): return torch.split(tensor, lst, axis) +def hsplit(tensor, indices_or_sections): + if tensor.ndim == 0: + raise ValueError("hsplit only works on arrays of 1 or more dimensions") + axis = 1 if tensor.ndim > 1 else 0 + return split_helper(tensor, indices_or_sections, axis, strict=True) + + +def vsplit(tensor, indices_or_sections): + if tensor.ndim < 2: + raise ValueError("vsplit only works on arrays of 2 or more dimensions") + return split_helper(tensor, indices_or_sections, 0, strict=True) + + +def dsplit(tensor, indices_or_sections): + if tensor.ndim < 3: + raise ValueError("dsplit only works on arrays of 3 or more dimensions") + return split_helper(tensor, indices_or_sections, 2, strict=True) + + + def clip(tensor, t_min, t_max): if t_min is not None: t_min = torch.broadcast_to(t_min, tensor.shape) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 775d2263..6f3c6f86 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -176,56 +176,35 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"): def array_split(ary, indices_or_sections, axis=0): tensor = asarray(ary).get() base = ary if isinstance(ary, ndarray) else None - axis = _util.normalize_axis_index(axis, tensor.ndim) - result = _impl.split_helper(tensor, indices_or_sections, axis) - return tuple(maybe_set_base(x, base) for x in result) def split(ary, indices_or_sections, axis=0): tensor = asarray(ary).get() base = ary if isinstance(ary, ndarray) else None - axis = _util.normalize_axis_index(axis, tensor.ndim) - result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) - return tuple(maybe_set_base(x, base) for x in result) def hsplit(ary, indices_or_sections): tensor = asarray(ary).get() base = ary if isinstance(ary, ndarray) else None - - if tensor.ndim == 0: - raise ValueError("hsplit only works on arrays of 1 or more dimensions") - - axis = 1 if tensor.ndim > 1 else 0 - - result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) - + result = _impl.hsplit(tensor, indices_or_sections) return tuple(maybe_set_base(x, base) for x in result) def vsplit(ary, indices_or_sections): tensor = asarray(ary).get() base = ary if isinstance(ary, ndarray) else None - - if tensor.ndim < 2: - raise ValueError("vsplit only works on arrays of 2 or more dimensions") - result = _impl.split_helper(tensor, indices_or_sections, 0, strict=True) - + result = _impl.vsplit(tensor, indices_or_sections) return tuple(maybe_set_base(x, base) for x in result) def dsplit(ary, indices_or_sections): tensor = asarray(ary).get() base = ary if isinstance(ary, ndarray) else None - - if tensor.ndim < 3: - raise ValueError("dsplit only works on arrays of 3 or more dimensions") - result = _impl.split_helper(tensor, indices_or_sections, 2, strict=True) - + result = _impl.dsplit(tensor, indices_or_sections) return tuple(maybe_set_base(x, base) for x in result) From ba31a0b0a41ef72e27f8495a794286d45238dcd0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 19 Feb 2023 13:46:57 +0300 Subject: [PATCH 20/20] MAINT: move tensor manipulations from _ndarray to _impl --- torch_np/_detail/implementations.py | 62 ++++++++++++++++++++++++++++- torch_np/_ndarray.py | 34 ++++------------ torch_np/_wrapper.py | 15 +++---- 3 files changed, 74 insertions(+), 37 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index d688f7cb..86d5ecef 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -216,7 +216,6 @@ def dsplit(tensor, indices_or_sections): return split_helper(tensor, indices_or_sections, 2, strict=True) - def clip(tensor, t_min, t_max): if t_min is not None: t_min = torch.broadcast_to(t_min, tensor.shape) @@ -488,3 +487,64 @@ def full(shape, fill_value, dtype=None): shape = (shape,) result = torch.full(shape, fill_value, dtype=dtype) return result + + +# ### shape manipulations ### + + +def roll(tensor, shift, axis=None): + if axis is not None: + axis = _util.normalize_axis_tuple(axis, tensor.ndim, allow_duplicate=True) + if not isinstance(shift, tuple): + shift = (shift,) * len(axis) + result = tensor.roll(shift, axis) + return result + + +def squeeze(tensor, axis=None): + if axis == (): + result = tensor + elif axis is None: + result = tensor.squeeze() + else: + result = tensor.squeeze(axis) + return result + + +def reshape(tensor, *shape, order="C"): + if order != "C": + raise NotImplementedError + newshape = shape[0] if len(shape) == 1 else shape + # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh) + result = tensor.reshape(newshape) + return result + + +def transpose(tensor, *axes): + # numpy allows both .reshape(sh) and .reshape(*sh) + axes = axes[0] if len(axes) == 1 else axes + if axes == () or axes is None: + axes = tuple(range(tensor.ndim))[::-1] + try: + result = tensor.permute(axes) + except RuntimeError: + raise ValueError("axes don't match array") + return result + + +# ### Numeric ### + + +def round(tensor, decimals=0): + if tensor.is_floating_point(): + result = torch.round(tensor, decimals=decimals) + elif tensor.is_complex(): + # RuntimeError: "round_cpu" not implemented for 'ComplexFloat' + result = ( + torch.round(tensor.real, decimals=decimals) + + torch.round(tensor.imag, decimals=decimals) * 1j + ) + else: + # RuntimeError: "round_cpu" not implemented for 'int' + result = tensor + return result diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index f310af7d..7f0940b6 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -139,11 +139,7 @@ def imag(self, value): self._tensor.imag = asarray(value).get() def round(self, decimals=0, out=None): - tensor = self._tensor - if torch.is_floating_point(tensor): - result = torch.round(tensor, decimals=decimals) - else: - result = tensor + result = _impl.round(self._tensor, decimals) return _helpers.result_or_out(result, out) # ctors @@ -328,32 +324,16 @@ def __irshift__(self, other): ### methods to match namespace functions def squeeze(self, axis=None): - if axis == (): - tensor = self._tensor - elif axis is None: - tensor = self._tensor.squeeze() - else: - tensor = self._tensor.squeeze(axis) - return ndarray._from_tensor_and_base(tensor, self) + result = _impl.squeeze(self._tensor, axis) + return ndarray._from_tensor_and_base(result, self) def reshape(self, *shape, order="C"): - newshape = shape[0] if len(shape) == 1 else shape - # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh) - if order != "C": - raise NotImplementedError - tensor = self._tensor.reshape(newshape) - return ndarray._from_tensor_and_base(tensor, self) + result = _impl.reshape(self._tensor, *shape, order=order) + return ndarray._from_tensor_and_base(result, self) def transpose(self, *axes): - # numpy allows both .reshape(sh) and .reshape(*sh) - axes = axes[0] if len(axes) == 1 else axes - if axes == () or axes is None: - axes = tuple(range(self.ndim))[::-1] - try: - tensor = self._tensor.permute(axes) - except RuntimeError: - raise ValueError("axes don't match array") - return ndarray._from_tensor_and_base(tensor, self) + result = _impl.transpose(self._tensor, *axes) + return ndarray._from_tensor_and_base(result, self) def swapaxes(self, axis1, axis2): return asarray(_flips.swapaxes(self._tensor, axis1, axis2)) diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 6f3c6f86..463c8319 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -595,9 +595,9 @@ def flatnonzero(a): def argwhere(a): - arr = asarray(a) - tensor = arr.get() - return asarray(torch.argwhere(tensor)) + tensor = asarray(a).get() + result = torch.argwhere(tensor) + return asarray(result) from ._decorators import emulate_out_arg @@ -606,13 +606,10 @@ def argwhere(a): count_nonzero = emulate_out_arg(axis_keepdims_wrapper(_reductions.count_nonzero)) -@asarray_replacer() def roll(a, shift, axis=None): - if axis is not None: - axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) - if not isinstance(shift, tuple): - shift = (shift,) * len(axis) - return a.roll(shift, axis) + tensor = asarray(a).get() + result = _impl.roll(tensor, shift, axis) + return asarray(result) def round_(a, decimals=0, out=None):