From d16d36a8a8733b8650e868bb49d4194a8a93481a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Feb 2023 13:03:12 +0300 Subject: [PATCH 1/4] ENH: implement split family (split, {array_, v, h, d}split) --- autogen/numpy_api_dump.py | 20 ------- torch_np/_detail/implementations.py | 60 +++++++++++++++++++ torch_np/_ndarray.py | 4 ++ torch_np/_wrapper.py | 57 ++++++++++++++++++ .../tests/numpy_tests/lib/test_shape_base_.py | 13 ++-- 5 files changed, 125 insertions(+), 29 deletions(-) diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index 4d4b8fab..e5f5fdb7 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -110,10 +110,6 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): raise NotImplementedError -def array_split(ary, indices_or_sections, axis=0): - raise NotImplementedError - - def array_str(a, max_line_width=None, precision=None, suppress_small=None): raise NotImplementedError @@ -260,10 +256,6 @@ def dot(a, b, out=None): raise NotImplementedError -def dsplit(ary, indices_or_sections): - raise NotImplementedError - - def ediff1d(ary, to_end=None, to_begin=None): raise NotImplementedError @@ -417,10 +409,6 @@ def histogramdd(sample, bins=10, range=None, normed=None, weights=None, density= raise NotImplementedError -def hsplit(ary, indices_or_sections): - raise NotImplementedError - - def in1d(ar1, ar2, assume_unique=False, invert=False): raise NotImplementedError @@ -875,10 +863,6 @@ def sort_complex(a): raise NotImplementedError -def split(ary, indices_or_sections, axis=0): - raise NotImplementedError - - def swapaxes(a, axis1, axis2): raise NotImplementedError @@ -947,10 +931,6 @@ def vdot(a, b, /): raise NotImplementedError -def vsplit(ary, indices_or_sections): - raise NotImplementedError - - def where(condition, x, y, /): raise NotImplementedError diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index b92432bb..bd33a806 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -15,3 +15,63 @@ def tensor_equal(a1_t, a2_t, equal_nan=False): else: result = a1_t == a2_t return bool(result.all()) + + +def split_helper(tensor, indices_or_sections, axis, strict=False): + if isinstance(indices_or_sections, int): + return split_helper_int(tensor, indices_or_sections, axis, strict) + elif isinstance(indices_or_sections, (list, tuple)): + return split_helper_list(tensor, list(indices_or_sections), axis, strict) + else: + raise TypeError("split_helper: ", type(indices_or_sections)) + + +def split_helper_int(tensor, indices_or_sections, axis, strict=False): + if not isinstance(indices_or_sections, int): + raise NotImplementedError("split: indices_or_sections") + + # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n + l, n = tensor.shape[axis], indices_or_sections + + if n <= 0: + raise ValueError() + + if l % n == 0: + num, sz = n, l // n + lst = [sz] * num + else: + if strict: + raise ValueError("array split does not result in an equal division") + + num, sz = l % n, l // n + 1 + lst = [sz] * num + + if n > l: + lst += [0] * (n - l) + else: + lrest = l - num * sz + + sz_1 = sz - 1 + num_1 = lrest // sz_1 + lst += [sz_1] * num_1 + + result = torch.split(tensor, lst, axis) + + return result + + +def split_helper_list(tensor, indices_or_sections, axis, strict=False): + if not isinstance(indices_or_sections, list): + raise NotImplementedError("split: indices_or_sections: list") + # numpy expectes indices, while torch expects lengths of sections + # also, numpy appends zero-size arrays for indices above the shape[axis] + lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] + num_extra = len(indices_or_sections) - len(lst) + + lst = lst + [tensor.shape[axis]] + lst = [ + lst[0], + ] + [a - b for a, b in zip(lst[1:], lst[:-1])] + lst += [0] * num_extra + + return torch.split(tensor, lst, axis) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 746fe4c5..dde3cd7e 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -409,6 +409,10 @@ def asarray(a, dtype=None, order=None, *, like=None): return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0) +def maybe_set_base(tensor, base): + return ndarray._from_tensor_and_base(tensor, base) + + class asarray_replacer: def __init__(self, dispatch="one"): if dispatch not in ["one", "two"]: diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 755efd01..af69cb4a 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -14,6 +14,7 @@ asarray, asarray_replacer, can_cast, + maybe_set_base, ndarray, newaxis, result_type, @@ -158,6 +159,62 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"): ) +def array_split(ary, indices_or_sections, axis=0): + tensor = asarray(ary).get() + base = ary if isinstance(ary, ndarray) else None + axis = _util.normalize_axis_index(axis, tensor.ndim) + + result = _impl.split_helper(tensor, indices_or_sections, axis) + + return tuple(maybe_set_base(_, base) for _ in result) + + +def split(ary, indices_or_sections, axis=0): + tensor = asarray(ary).get() + base = ary if isinstance(ary, ndarray) else None + axis = _util.normalize_axis_index(axis, tensor.ndim) + + result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) + + return tuple(maybe_set_base(_, base) for _ in result) + + +def hsplit(ary, indices_or_sections): + tensor = asarray(ary).get() + base = ary if isinstance(ary, ndarray) else None + + if tensor.ndim == 0: + raise ValueError("hsplit only works on arrays of 1 or more dimensions") + + axis = 1 if tensor.ndim > 1 else 0 + + result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) + + return tuple(maybe_set_base(_, base) for _ in result) + + +def vsplit(ary, indices_or_sections): + tensor = asarray(ary).get() + base = ary if isinstance(ary, ndarray) else None + + if tensor.ndim < 2: + raise ValueError("vsplit only works on arrays of 2 or more dimensions") + result = _impl.split_helper(tensor, indices_or_sections, 0, strict=True) + + return tuple(maybe_set_base(_, base) for _ in result) + + +def dsplit(ary, indices_or_sections): + tensor = asarray(ary).get() + base = ary if isinstance(ary, ndarray) else None + + if tensor.ndim < 3: + raise ValueError("dsplit only works on arrays of 3 or more dimensions") + result = _impl.split_helper(tensor, indices_or_sections, 2, strict=True) + + return tuple(maybe_set_base(_, base) for _ in result) + + def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): if axis != 0 or retstep or not endpoint: raise NotImplementedError diff --git a/torch_np/tests/numpy_tests/lib/test_shape_base_.py b/torch_np/tests/numpy_tests/lib/test_shape_base_.py index a44a1b5c..263028bf 100644 --- a/torch_np/tests/numpy_tests/lib/test_shape_base_.py +++ b/torch_np/tests/numpy_tests/lib/test_shape_base_.py @@ -2,12 +2,12 @@ import sys import pytest -from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, array_split, - split, hsplit, dsplit, vsplit, kron, tile, - expand_dims, take_along_axis, put_along_axis) +from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, kron, tile, + take_along_axis, put_along_axis) import torch_np as np -from torch_np import column_stack, dstack, expand_dims +from torch_np import (column_stack, dstack, expand_dims, array_split, + split, hsplit, dsplit, vsplit,) from torch_np.random import rand @@ -275,7 +275,6 @@ def test_repeated_axis(self): assert_raises(ValueError, expand_dims, a, axis=(1, 1)) -@pytest.mark.xfail(reason="TODO: implement") class TestArraySplit: def test_integer_0_split(self): a = np.arange(10) @@ -410,7 +409,6 @@ def test_index_split_high_bound(self): compare_results(res, desired) -@pytest.mark.xfail(reason="TODO: implement") class TestSplit: # The split function is essentially the same as array_split, # except that it test if splitting will result in an @@ -493,7 +491,6 @@ def test_generator(self): # array_split has more comprehensive test of splitting. # only do simple test on hsplit, vsplit, and dsplit -@pytest.mark.xfail(reason="TODO: implement") class TestHsplit: """Only testing for integer splits. @@ -523,7 +520,6 @@ def test_2D_array(self): compare_results(res, desired) -@pytest.mark.xfail(reason="TODO: implement") class TestVsplit: """Only testing for integer splits. @@ -551,7 +547,6 @@ def test_2D_array(self): compare_results(res, desired) -@pytest.mark.xfail(reason="TODO: implement") class TestDsplit: # Only testing for integer splits. def test_non_iterable(self): From 51a8d12dc317d83df1d43a00c5e50407f8111ee5 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Feb 2023 19:58:56 +0300 Subject: [PATCH 2/4] ENH: tile, kron --- autogen/numpy_api_dump.py | 4 ---- torch_np/_wrapper.py | 15 ++++++++++++++ .../tests/numpy_tests/lib/test_shape_base_.py | 20 +++---------------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index e5f5fdb7..51972913 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -481,10 +481,6 @@ def kaiser(M, beta): raise NotImplementedError -def kron(a, b): - raise NotImplementedError - - def lexsort(keys, axis=-1): raise NotImplementedError diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index af69cb4a..03e0c239 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -215,6 +215,21 @@ def dsplit(ary, indices_or_sections): return tuple(maybe_set_base(_, base) for _ in result) +def kron(a, b): + a_tensor, b_tensor = _helpers.to_tensors(a, b) + result = torch.kron(a_tensor, b_tensor) + return asarray(result) + + +def tile(A, reps): + a_tensor = asarray(A).get() + if isinstance(reps, int): + reps = (reps,) + + result = torch.tile(a_tensor, reps) + return asarray(result) + + def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): if axis != 0 or retstep or not endpoint: raise NotImplementedError diff --git a/torch_np/tests/numpy_tests/lib/test_shape_base_.py b/torch_np/tests/numpy_tests/lib/test_shape_base_.py index 263028bf..3400f4b2 100644 --- a/torch_np/tests/numpy_tests/lib/test_shape_base_.py +++ b/torch_np/tests/numpy_tests/lib/test_shape_base_.py @@ -2,14 +2,14 @@ import sys import pytest -from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, kron, tile, +from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, take_along_axis, put_along_axis) import torch_np as np from torch_np import (column_stack, dstack, expand_dims, array_split, - split, hsplit, dsplit, vsplit,) + split, hsplit, dsplit, vsplit, kron, tile,) -from torch_np.random import rand +from torch_np.random import rand, randint from torch_np.testing import assert_array_equal, assert_equal, assert_ from pytest import raises as assert_raises @@ -635,7 +635,6 @@ def test_squeeze_axis_handling(self): np.squeeze(np.array([[1], [2], [3]]), axis=0) -@pytest.mark.xfail(reason="TODO: implement") class TestKron: def test_basic(self): # Using 0-dimensional ndarray @@ -666,16 +665,6 @@ def test_basic(self): k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) assert_array_equal(np.kron(a, b), k) - def test_return_type(self): - class myarray(np.ndarray): - __array_priority__ = 1.0 - - a = np.ones([2, 2]) - ma = myarray(a.shape, a.dtype, a.data) - assert_equal(type(kron(a, a)), np.ndarray) - assert_equal(type(kron(ma, ma)), myarray) - assert_equal(type(kron(a, ma)), myarray) - assert_equal(type(kron(ma, a)), myarray) @pytest.mark.parametrize( "shape_a,shape_b", [ @@ -698,7 +687,6 @@ def test_kron_shape(self, shape_a, shape_b): k.shape, expected_shape), "Unexpected shape from kron" -@pytest.mark.xfail(reason="TODO: implement") class TestTile: def test_basic(self): a = np.array([0, 1, 2]) @@ -726,8 +714,6 @@ def test_empty(self): assert_equal(d, (3, 2, 0)) def test_kroncompare(self): - from numpy.random import randint - reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)] shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)] for s in shape: From 3dfb8724be4f99cb44d899f8095adea31c96c662 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Feb 2023 20:21:44 +0300 Subject: [PATCH 3/4] Update torch_np/_detail/implementations.py Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> --- torch_np/_detail/implementations.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index bd33a806..e45cd4d8 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -46,14 +46,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False): num, sz = l % n, l // n + 1 lst = [sz] * num - if n > l: - lst += [0] * (n - l) - else: - lrest = l - num * sz - - sz_1 = sz - 1 - num_1 = lrest // sz_1 - lst += [sz_1] * num_1 + lst += [sz - 1] * (n - num) result = torch.split(tensor, lst, axis) From a467896dc57bfbed605bbfbd6a8b329fb747a291 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Feb 2023 20:23:54 +0300 Subject: [PATCH 4/4] MAINT: apply suggestions from review --- autogen/numpy_api_dump.py | 4 ---- torch_np/_detail/implementations.py | 4 ++-- torch_np/_wrapper.py | 10 +++++----- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index 51972913..67c0c981 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -875,10 +875,6 @@ def tensordot(a, b, axes=2): raise NotImplementedError -def tile(A, reps): - raise NotImplementedError - - def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None): raise NotImplementedError diff --git a/torch_np/_detail/implementations.py b/torch_np/_detail/implementations.py index e45cd4d8..8dc3301d 100644 --- a/torch_np/_detail/implementations.py +++ b/torch_np/_detail/implementations.py @@ -46,7 +46,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False): num, sz = l % n, l // n + 1 lst = [sz] * num - lst += [sz - 1] * (n - num) + lst += [sz - 1] * (n - num) result = torch.split(tensor, lst, axis) @@ -61,7 +61,7 @@ def split_helper_list(tensor, indices_or_sections, axis, strict=False): lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] num_extra = len(indices_or_sections) - len(lst) - lst = lst + [tensor.shape[axis]] + lst.append(tensor.shape[axis]) lst = [ lst[0], ] + [a - b for a, b in zip(lst[1:], lst[:-1])] diff --git a/torch_np/_wrapper.py b/torch_np/_wrapper.py index 03e0c239..06fcb077 100644 --- a/torch_np/_wrapper.py +++ b/torch_np/_wrapper.py @@ -166,7 +166,7 @@ def array_split(ary, indices_or_sections, axis=0): result = _impl.split_helper(tensor, indices_or_sections, axis) - return tuple(maybe_set_base(_, base) for _ in result) + return tuple(maybe_set_base(x, base) for x in result) def split(ary, indices_or_sections, axis=0): @@ -176,7 +176,7 @@ def split(ary, indices_or_sections, axis=0): result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) - return tuple(maybe_set_base(_, base) for _ in result) + return tuple(maybe_set_base(x, base) for x in result) def hsplit(ary, indices_or_sections): @@ -190,7 +190,7 @@ def hsplit(ary, indices_or_sections): result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True) - return tuple(maybe_set_base(_, base) for _ in result) + return tuple(maybe_set_base(x, base) for x in result) def vsplit(ary, indices_or_sections): @@ -201,7 +201,7 @@ def vsplit(ary, indices_or_sections): raise ValueError("vsplit only works on arrays of 2 or more dimensions") result = _impl.split_helper(tensor, indices_or_sections, 0, strict=True) - return tuple(maybe_set_base(_, base) for _ in result) + return tuple(maybe_set_base(x, base) for x in result) def dsplit(ary, indices_or_sections): @@ -212,7 +212,7 @@ def dsplit(ary, indices_or_sections): raise ValueError("dsplit only works on arrays of 3 or more dimensions") result = _impl.split_helper(tensor, indices_or_sections, 2, strict=True) - return tuple(maybe_set_base(_, base) for _ in result) + return tuple(maybe_set_base(x, base) for x in result) def kron(a, b):