From 69be71f6d0019de83bd0c284cce26aa948220467 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 10 Jan 2023 16:43:53 +0300 Subject: [PATCH 1/4] MAINT: add prod, var, std to follow sum --- torch_np/_helpers.py | 10 ++++++ torch_np/_ndarray.py | 57 +++++++++++++++++++++++------- torch_np/_wrapper.py | 58 +++++++++++-------------------- torch_np/tests/test_reductions.py | 21 +++++++++++ 4 files changed, 96 insertions(+), 50 deletions(-) diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py index 00b8faa7..75d0f821 100644 --- a/torch_np/_helpers.py +++ b/torch_np/_helpers.py @@ -131,3 +131,13 @@ def to_tensors(*inputs): """Convert all ndarrays from `inputs` to tensors.""" return tuple([value.get() if isinstance(value, ndarray) else value for value in inputs]) + + +def float_or_default(dtype, self_dtype): + """dtype helper for reductions.""" + if dtype is None: + dtype = self_dtype + if _dtypes.is_integer(dtype): + dtype = _dtypes.default_float_type() + torch_dtype = _dtypes.torch_dtype_from(dtype) + return torch_dtype diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 16a0691a..b9e8bb2c 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -416,12 +416,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal if where is not None: raise NotImplementedError - if dtype is None: - dtype = self.dtype - if _dtypes.is_integer(dtype): - dtype = _dtypes.default_float_type() - torch_dtype = _dtypes.torch_dtype_from(dtype) - + torch_dtype = _helpers.float_or_default(dtype, self.dtype) if axis is None: result = self._tensor.mean(dtype=torch_dtype) else: @@ -436,12 +431,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue, if initial is not None or where is not None: raise NotImplementedError - if dtype is None: - dtype = self.dtype - if _dtypes.is_integer(dtype): - dtype = _dtypes.default_float_type() - torch_dtype = _dtypes.torch_dtype_from(dtype) - + torch_dtype = _helpers.float_or_default(dtype, self.dtype) if axis is None: result = self._tensor.sum(dtype=torch_dtype) else: @@ -449,6 +439,49 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue, return result + @axis_out_keepdims_wrapper + def prod(self, axis=None, dtype=None, out=None, keepdims=NoValue, + initial=NoValue, where=NoValue): + if initial is not None or where is not None: + raise NotImplementedError + + axis = _helpers.allow_only_single_axis(axis) + + torch_dtype = _helpers.float_or_default(dtype, self.dtype) + if axis is None: + result = self._tensor.prod(dtype=torch_dtype) + else: + result = self._tensor.prod(dtype=torch_dtype, dim=axis) + + return result + + + @axis_out_keepdims_wrapper + def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, + where=NoValue): + if where is not None: + raise NotImplementedError + + torch_dtype = _helpers.float_or_default(dtype, self.dtype) + tensor = self._tensor.to(torch_dtype) # XXX: needed? + + result = tensor.std(dim=axis, correction=ddof) + + return result + + @axis_out_keepdims_wrapper + def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, + where=NoValue): + if where is not None: + raise NotImplementedError + + torch_dtype = _helpers.float_or_default(dtype, self.dtype) + tensor = self._tensor.to(torch_dtype) # XXX: needed? 
+ + result = tensor.var(dim=axis, correction=ddof) + + return result + ### indexing ### def __getitem__(self, *args, **kwds): diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 9070718f..92592cc0 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -307,20 +307,6 @@ def identity(n, dtype=None, *, like=None): ###### misc/unordered -#YYY: pattern: initial=... -@asarray_replacer() -def prod(a, axis=None, dtype=None, out=None, keepdims=NoValue, - initial=NoValue, where=NoValue): - if initial is not None or where is not None: - raise NotImplementedError - if axis is None: - if keepdims is not None: - raise NotImplementedError - return torch.prod(a, dtype=dtype) - elif _util.is_sequence(axis): - raise NotImplementedError - return torch.prod(a, dim=axis, dtype=dtype, keepdim=bool(keepdims), out=out) - @asarray_replacer() @@ -639,6 +625,8 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue) return arr.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) +#YYY: pattern: initial=... + def sum(a, axis=None, dtype=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue): arr = asarray(a) @@ -646,6 +634,24 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=NoValue, initial=initial, where=where) +def prod(a, axis=None, dtype=None, out=None, keepdims=NoValue, + initial=NoValue, where=NoValue): + arr = asarray(a) + return arr.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where=where) + + +#YYY: pattern : ddof + +def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): + arr = asarray(a) + return arr.std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where) + +def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): + arr = asarray(a) + return arr.var(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, where=where) + + @asarray_replacer() def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue): if where is not None: @@ -663,30 +669,6 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal return result -# YYY: pattern : std, var -@asarray_replacer() -def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): - if where is not None: - raise NotImplementedError - if dtype is not None: - raise NotImplementedError - if not torch.is_floating_point(a): - a = a * 1.0 - return torch.std(a, axis, correction=ddof, keepdim=bool(keepdims), out=out) - - -@asarray_replacer() -def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue): - if where is not None: - raise NotImplementedError - if dtype is not None: - raise NotImplementedError - if not torch.is_floating_point(a): - a = a * 1.0 - return torch.var(a, axis, correction=ddof, keepdim=bool(keepdims), out=out) - - - @asarray_replacer() def argsort(a, axis=-1, kind=None, order=None): if order is not None: diff --git a/torch_np/tests/test_reductions.py b/torch_np/tests/test_reductions.py index 108b358e..e0ef345c 100644 --- a/torch_np/tests/test_reductions.py +++ b/torch_np/tests/test_reductions.py @@ -593,3 +593,24 @@ def setup_method(self): self.allowed_axes = [0, 1, 2, -1, -2, (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)] + +class TestProdGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin): + def setup_method(self): + self.func = np.prod + self.allowed_axes = [0, 1, 2, -1, -2,] +# (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)] + + +class 
TestStdGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin): + def setup_method(self): + self.func = np.std + self.allowed_axes = [0, 1, 2, -1, -2, + (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)] + + +class TestVarGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin): + def setup_method(self): + self.func = np.var + self.allowed_axes = [0, 1, 2, -1, -2, + (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)] + From 476acdbad55accb89a85d6900c64b6ef88490fa4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 10 Jan 2023 17:19:14 +0300 Subject: [PATCH 2/4] BUG: sum() of a bool array is integer --- torch_np/_helpers.py | 9 ++++++--- torch_np/_ndarray.py | 6 +++--- torch_np/tests/test_reductions.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/torch_np/_helpers.py b/torch_np/_helpers.py index 75d0f821..201f8166 100644 --- a/torch_np/_helpers.py +++ b/torch_np/_helpers.py @@ -133,11 +133,14 @@ def to_tensors(*inputs): for value in inputs]) -def float_or_default(dtype, self_dtype): +def float_or_default(dtype, self_dtype, enforce_float=False): """dtype helper for reductions.""" if dtype is None: dtype = self_dtype - if _dtypes.is_integer(dtype): - dtype = _dtypes.default_float_type() + if dtype == _dtypes.dtype('bool'): + dtype = _dtypes.default_int_type() + if enforce_float: + if _dtypes.is_integer(dtype): + dtype = _dtypes.default_float_type() torch_dtype = _dtypes.torch_dtype_from(dtype) return torch_dtype diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index b9e8bb2c..00406248 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -416,7 +416,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal if where is not None: raise NotImplementedError - torch_dtype = _helpers.float_or_default(dtype, self.dtype) + torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) if axis is None: result = self._tensor.mean(dtype=torch_dtype) else: @@ -462,7 +462,7 @@ def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, if where is not None: raise NotImplementedError - torch_dtype = _helpers.float_or_default(dtype, self.dtype) + torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) tensor = self._tensor.to(torch_dtype) # XXX: needed? result = tensor.std(dim=axis, correction=ddof) @@ -475,7 +475,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, if where is not None: raise NotImplementedError - torch_dtype = _helpers.float_or_default(dtype, self.dtype) + torch_dtype = _helpers.float_or_default(dtype, self.dtype, enforce_float=True) tensor = self._tensor.to(torch_dtype) # XXX: needed? 
        result = tensor.var(dim=axis, correction=ddof)

diff --git a/torch_np/tests/test_reductions.py b/torch_np/tests/test_reductions.py
index e0ef345c..06714667 100644
--- a/torch_np/tests/test_reductions.py
+++ b/torch_np/tests/test_reductions.py
@@ -328,6 +328,16 @@ def test_sum_stability(self):
         assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13,
                         check_dtype=False)
 
+    def test_sum_boolean(self):
+        a = (np.arange(7) % 2 == 0)
+        res = a.sum()
+        assert_equal(res, 4)
+
+        res_float = a.sum(dtype=np.float64)
+        assert_allclose(res_float, 4.0, atol=1e-15)
+        assert res_float.dtype == 'float64'
+
+
     @pytest.mark.xfail(reason="dtype(value) needs implementing")
     def test_sum_dtypes(self):
         for dt in (int, np.float16, np.float32, np.float64):

From 1df46190ca457649baffcda2565328c90e625cc0 Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Tue, 10 Jan 2023 19:59:39 +0300
Subject: [PATCH 3/4] MAINT: make .real/imag attributes writeable

---
 torch_np/_ndarray.py              |  8 ++++
 torch_np/_wrapper.py              |  5 ++
 torch_np/testing/__init__.py      |  4 +-
 torch_np/testing/testing.py       | 33 -------------
 torch_np/testing/utils.py         |  9 ++--
 torch_np/tests/test_reductions.py | 79 ++++++++++++++++---------------
 6 files changed, 63 insertions(+), 75 deletions(-)
 delete mode 100644 torch_np/testing/testing.py

diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index 00406248..888413f2 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -86,6 +86,10 @@ def T(self):
     def real(self):
         return asarray(self._tensor.real)
 
+    @real.setter
+    def real(self, value):
+        self._tensor.real = asarray(value).get()
+
     @property
     def imag(self):
         try:
@@ -94,6 +98,10 @@ def imag(self):
             zeros = torch.zeros_like(self._tensor)
             return ndarray._from_tensor_and_base(zeros, None)
 
+    @imag.setter
+    def imag(self, value):
+        self._tensor.imag = asarray(value).get()
+
     # ctors
     def astype(self, dtype):
         newt = ndarray()
diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py
index 92592cc0..045a3089 100644
--- a/torch_np/_wrapper.py
+++ b/torch_np/_wrapper.py
@@ -757,6 +757,11 @@ def isscalar(a):
     return False
 
 
+def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
+    a = asarray(a).get()
+    b = asarray(b).get()
+    return asarray(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
+
 ###### mapping from numpy API objects to wrappers from this module ######
 
 # All is in the mapping dict in _mapping.py
diff --git a/torch_np/testing/__init__.py b/torch_np/testing/__init__.py
index 98755a5b..03838722 100644
--- a/torch_np/testing/__init__.py
+++ b/torch_np/testing/__init__.py
@@ -1,7 +1,7 @@
 
 from .utils import (assert_equal, assert_array_equal, assert_almost_equal,
-                    assert_warns, assert_)
+                    assert_warns, assert_, assert_allclose)
 
 from .utils import _gen_alignment_data
 
-from .testing import assert_allclose  # FIXME
+#from .testing import assert_allclose  # FIXME
diff --git a/torch_np/testing/testing.py b/torch_np/testing/testing.py
deleted file mode 100644
index fbe0d34b..00000000
--- a/torch_np/testing/testing.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import torch
-
-from .._ndarray import asarray
-import torch_np as np
-
-def assert_allclose(actual, desired, rtol=1e-07, atol=0, equal_nan=True,
-                    err_msg='', verbose=True, check_dtype=True):
-    actual = asarray(actual).get()
-    desired = asarray(desired).get()
-    result = torch.testing.assert_close(actual, desired, atol=atol, rtol=rtol,
-                                        check_dtype=check_dtype)
-    return result
-
-
-def assert_equal(actual, desired):
-    """Check `actual == desired`, broadcast if needed """
-    actual = np.asarray(actual)
-    
desired = np.asarray(desired) - eq = np.all(actual == desired) - if not eq: - raise AssertionError('not equal') - return eq - - - -def assert_array_equal(actual, desired): - """Check that actual == desired, both shapes and values.""" - a_actual = asarray(actual) - a_desired = asarray(desired) - - assert a_actual.shape == a_desired.shape - assert (a_actual == a_desired).all() - diff --git a/torch_np/testing/utils.py b/torch_np/testing/utils.py index e33d5acd..5525e8e8 100644 --- a/torch_np/testing/utils.py +++ b/torch_np/testing/utils.py @@ -1169,7 +1169,7 @@ def _assert_valid_refcount(op): def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, - err_msg='', verbose=True): + err_msg='', verbose=True, check_dtype=False): """ Raises an AssertionError if two objects are not equal up to desired tolerance. @@ -1226,14 +1226,17 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, """ __tracebackhide__ = True # Hide traceback for py.test - import numpy as np def compare(x, y): - return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol, + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) actual, desired = asanyarray(actual), asanyarray(desired) header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}' + + if check_dtype: + assert actual.dtype == desired.dtype + assert_array_compare(compare, actual, desired, err_msg=str(err_msg), verbose=verbose, header=header, equal_nan=equal_nan) diff --git a/torch_np/tests/test_reductions.py b/torch_np/tests/test_reductions.py index 06714667..b996f0d5 100644 --- a/torch_np/tests/test_reductions.py +++ b/torch_np/tests/test_reductions.py @@ -2,7 +2,8 @@ from pytest import raises as assert_raises import torch_np as np -from torch_np.testing import assert_equal, assert_array_equal, assert_allclose +from torch_np.testing import (assert_equal, assert_array_equal, assert_allclose, + assert_almost_equal) import torch_np._util as _util @@ -321,12 +322,10 @@ def test_sum(self): def test_sum_stability(self): a = np.ones(500, dtype=np.float32) zero = np.zeros(1, dtype='float32')[0] - assert_allclose((a / 10.).sum() - a.size / 10., zero, atol=1.5e-4, - check_dtype=False) + assert_allclose((a / 10.).sum() - a.size / 10., zero, atol=1.5e-4) a = np.ones(500, dtype=np.float64) - assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13, - check_dtype=False) + assert_allclose((a / 10.).sum() - a.size / 10., 0., atol=1.5e-13) def test_sum_boolean(self): a = (np.arange(7) % 2 == 0) @@ -338,8 +337,8 @@ def test_sum_boolean(self): assert res_float.dtype == 'float64' - @pytest.mark.xfail(reason="dtype(value) needs implementing") - def test_sum_dtypes(self): + @pytest.mark.xfail(reason="sum: does not warn on overflow") + def test_sum_dtypes_warnings(self): for dt in (int, np.float16, np.float32, np.float64): for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, 128, 1024, 1235): @@ -357,48 +356,54 @@ def test_sum_dtypes(self): assert_almost_equal(np.sum(d), tgt) assert_equal(len(w), 2 * overflow) - assert_almost_equal(np.sum(d[::-1]), tgt) + assert_almost_equal(np.sum(np.flip(d)), tgt) assert_equal(len(w), 3 * overflow) + def test_sum_dtypes_2(self): + for dt in (int, np.float16, np.float32, np.float64): d = np.ones(500, dtype=dt) assert_almost_equal(np.sum(d[::2]), 250.) assert_almost_equal(np.sum(d[1::2]), 250.) assert_almost_equal(np.sum(d[::3]), 167.) assert_almost_equal(np.sum(d[1::3]), 167.) - assert_almost_equal(np.sum(d[::-2]), 250.) - assert_almost_equal(np.sum(d[-1::-2]), 250.) 
- assert_almost_equal(np.sum(d[::-3]), 167.) - assert_almost_equal(np.sum(d[-1::-3]), 167.) + assert_almost_equal(np.sum(np.flip(d)[::2]), 250.) + + assert_almost_equal(np.sum(np.flip(d)[1::2]), 250.) + + assert_almost_equal(np.sum(np.flip(d)[::3]), 167.) + assert_almost_equal(np.sum(np.flip(d)[1::3]), 167.) + # sum with first reduction entry != 0 d = np.ones((1,), dtype=dt) d += d assert_almost_equal(d, 2.) - @pytest.mark.xfail(reason="dtype(value) needs implementing") - def test_sum_complex(self): - for dt in (np.complex64, np.complex128): - for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, - 128, 1024, 1235): - tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j) - d = np.empty(v, dtype=dt) - d.real = np.arange(1, v + 1) - d.imag = -np.arange(1, v + 1) - assert_allclose(np.sum(d), tgt, atol=1.5e-5) - assert_allcllose(np.sum(d[::-1]), tgt, atol=1.5e-7) - - d = np.ones(500, dtype=dt) + 1j - assert_allclose(np.sum(d[::2]), 250. + 250j, atol=1.5e-7) - assert_allclose(np.sum(d[1::2]), 250. + 250j, atol=1.5e-7) - assert_allclose(np.sum(d[::3]), 167. + 167j, atol=1.5e-7) - assert_allclose(np.sum(d[1::3]), 167. + 167j, atol=1.5e-7) - assert_allclose(np.sum(d[::-2]), 250. + 250j, atol=1.5e-7) - assert_allclose(np.sum(d[-1::-2]), 250. + 250j, atol=1.5e-7) - assert_allclose(np.sum(d[::-3]), 167. + 167j, atol=1.5e-7) - assert_allclose(np.sum(d[-1::-3]), 167. + 167j, atol=1.5e-7) - # sum with first reduction entry != 0 - d = np.ones((1,), dtype=dt) + 1j - d += d - assert_allclose(d, 2. + 2j, atol=1.5e-7) + @pytest.mark.parametrize("dt", [np.complex64, np.complex128]) + def test_sum_complex_1(self, dt): + for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, + 128, 1024, 1235): + tgt = dt(v * (v + 1) / 2) - dt((v * (v + 1) / 2) * 1j) + d = np.empty(v, dtype=dt) + d.real = np.arange(1, v + 1) + d.imag = -np.arange(1, v + 1) + assert_allclose(np.sum(d), tgt, atol=1.5e-5) + assert_allclose(np.sum(np.flip(d)), tgt, atol=1.5e-7) + + @pytest.mark.parametrize("dt", [np.complex64, np.complex128]) + def test_sum_complex_2(self, dt): + d = np.ones(500, dtype=dt) + 1j + assert_allclose(np.sum(d[::2]), 250. + 250j, atol=1.5e-7) + assert_allclose(np.sum(d[1::2]), 250. + 250j, atol=1.5e-7) + assert_allclose(np.sum(d[::3]), 167. + 167j, atol=1.5e-7) + assert_allclose(np.sum(d[1::3]), 167. + 167j, atol=1.5e-7) + assert_allclose(np.sum(np.flip(d)[::2]), 250. + 250j, atol=1.5e-7) + assert_allclose(np.sum(np.flip(d)[1::2]), 250. + 250j, atol=1.5e-7) + assert_allclose(np.sum(np.flip(d)[::3]), 167. + 167j, atol=1.5e-7) + assert_allclose(np.sum(np.flip(d)[1::3]), 167. + 167j, atol=1.5e-7) + # sum with first reduction entry != 0 + d = np.ones((1,), dtype=dt) + 1j + d += d + assert_allclose(d, 2. + 2j, atol=1.5e-7) @pytest.mark.xfail(reason='initial=... 
need implementing')
    def test_sum_initial(self):

From a592826829fca10cfdf062f199293ea86eeb72cd Mon Sep 17 00:00:00 2001
From: Evgeni Burovski
Date: Thu, 12 Jan 2023 20:32:26 +0600
Subject: [PATCH 4/4] MAINT: golf around axis being None or not None

Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
---
 torch_np/_ndarray.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index 888413f2..179c136e 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -456,12 +456,8 @@ def prod(self, axis=None, dtype=None, out=None, keepdims=NoValue,
         axis = _helpers.allow_only_single_axis(axis)
 
         torch_dtype = _helpers.float_or_default(dtype, self.dtype)
-        if axis is None:
-            result = self._tensor.prod(dtype=torch_dtype)
-        else:
-            result = self._tensor.prod(dtype=torch_dtype, dim=axis)
-
-        return result
+        kwargs = {"dim": axis} if axis is not None else {}
+        return self._tensor.prod(dtype=torch_dtype, **kwargs)
 
 
     @axis_out_keepdims_wrapper
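
Usage sketch (reviewer note, not part of the patches): a minimal smoke test of
the behavior this series adds. It is a hedged sketch, not the project's test
suite; it assumes torch_np exposes arange/empty/reshape and the usual ndarray
arithmetic, and uses the assert helpers re-exported from torch_np.testing in
PATCH 3.

    import torch_np as np
    from torch_np.testing import assert_allclose, assert_equal

    a = np.arange(1, 7).reshape(2, 3)

    # np.prod/np.std/np.var are thin wrappers over the ndarray methods (PATCH 1);
    # integer inputs stay integer for prod, but std/var enforce a float dtype.
    assert_equal(np.prod(a), 720)
    assert_allclose(np.var(a), 35. / 12., atol=1e-12)   # population variance of 1..6
    assert_allclose(np.std(a) ** 2, np.var(a), atol=1e-12)

    # ddof maps onto torch's `correction` argument: var(ddof=1) == var() * n / (n - 1)
    assert_allclose(np.var(a, ddof=1), np.var(a) * a.size / (a.size - 1), atol=1e-12)

    # sum() of a bool array is an integer unless a float dtype is requested (PATCH 2)
    b = np.arange(7) % 2 == 0
    assert_equal(b.sum(), 4)
    assert b.sum(dtype=np.float64).dtype == 'float64'

    # .real and .imag are now writeable (PATCH 3)
    d = np.empty(3, dtype=np.complex128)
    d.real = np.arange(3)
    d.imag = -np.arange(3)
    assert_allclose(np.sum(d), 3.0 - 3.0j, atol=1e-12)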