Skip to content

WIP: implement {array_, VHD}split #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions autogen/numpy_api_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
raise NotImplementedError


def array_split(ary, indices_or_sections, axis=0):
raise NotImplementedError


def array_str(a, max_line_width=None, precision=None, suppress_small=None):
raise NotImplementedError

Expand Down Expand Up @@ -260,10 +256,6 @@ def dot(a, b, out=None):
raise NotImplementedError


def dsplit(ary, indices_or_sections):
raise NotImplementedError


def ediff1d(ary, to_end=None, to_begin=None):
raise NotImplementedError

Expand Down Expand Up @@ -417,10 +409,6 @@ def histogramdd(sample, bins=10, range=None, normed=None, weights=None, density=
raise NotImplementedError


def hsplit(ary, indices_or_sections):
raise NotImplementedError


def in1d(ar1, ar2, assume_unique=False, invert=False):
raise NotImplementedError

Expand Down Expand Up @@ -493,10 +481,6 @@ def kaiser(M, beta):
raise NotImplementedError


def kron(a, b):
raise NotImplementedError


def lexsort(keys, axis=-1):
raise NotImplementedError

Expand Down Expand Up @@ -875,10 +859,6 @@ def sort_complex(a):
raise NotImplementedError


def split(ary, indices_or_sections, axis=0):
raise NotImplementedError


def swapaxes(a, axis1, axis2):
raise NotImplementedError

Expand All @@ -895,10 +875,6 @@ def tensordot(a, b, axes=2):
raise NotImplementedError


def tile(A, reps):
raise NotImplementedError


def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
raise NotImplementedError

Expand Down Expand Up @@ -947,10 +923,6 @@ def vdot(a, b, /):
raise NotImplementedError


def vsplit(ary, indices_or_sections):
raise NotImplementedError


def where(condition, x, y, /):
raise NotImplementedError

Expand Down
53 changes: 53 additions & 0 deletions torch_np/_detail/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,56 @@ def tensor_equal(a1_t, a2_t, equal_nan=False):
else:
result = a1_t == a2_t
return bool(result.all())


def split_helper(tensor, indices_or_sections, axis, strict=False):
if isinstance(indices_or_sections, int):
return split_helper_int(tensor, indices_or_sections, axis, strict)
elif isinstance(indices_or_sections, (list, tuple)):
return split_helper_list(tensor, list(indices_or_sections), axis, strict)
else:
raise TypeError("split_helper: ", type(indices_or_sections))


def split_helper_int(tensor, indices_or_sections, axis, strict=False):
if not isinstance(indices_or_sections, int):
raise NotImplementedError("split: indices_or_sections")

# numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
l, n = tensor.shape[axis], indices_or_sections

if n <= 0:
raise ValueError()

if l % n == 0:
num, sz = n, l // n
lst = [sz] * num
else:
if strict:
raise ValueError("array split does not result in an equal division")

num, sz = l % n, l // n + 1
lst = [sz] * num

lst += [sz - 1] * (n - num)

result = torch.split(tensor, lst, axis)

return result


def split_helper_list(tensor, indices_or_sections, axis, strict=False):
if not isinstance(indices_or_sections, list):
raise NotImplementedError("split: indices_or_sections: list")
# numpy expectes indices, while torch expects lengths of sections
# also, numpy appends zero-size arrays for indices above the shape[axis]
lst = [x for x in indices_or_sections if x <= tensor.shape[axis]]
num_extra = len(indices_or_sections) - len(lst)

lst.append(tensor.shape[axis])
lst = [
lst[0],
] + [a - b for a, b in zip(lst[1:], lst[:-1])]
lst += [0] * num_extra

return torch.split(tensor, lst, axis)
4 changes: 4 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ def asarray(a, dtype=None, order=None, *, like=None):
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def maybe_set_base(tensor, base):
return ndarray._from_tensor_and_base(tensor, base)


class asarray_replacer:
def __init__(self, dispatch="one"):
if dispatch not in ["one", "two"]:
Expand Down
72 changes: 72 additions & 0 deletions torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
asarray,
asarray_replacer,
can_cast,
maybe_set_base,
ndarray,
newaxis,
result_type,
Expand Down Expand Up @@ -158,6 +159,77 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
)


def array_split(ary, indices_or_sections, axis=0):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
axis = _util.normalize_axis_index(axis, tensor.ndim)

result = _impl.split_helper(tensor, indices_or_sections, axis)

return tuple(maybe_set_base(x, base) for x in result)


def split(ary, indices_or_sections, axis=0):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
axis = _util.normalize_axis_index(axis, tensor.ndim)

result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)

return tuple(maybe_set_base(x, base) for x in result)


def hsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None

if tensor.ndim == 0:
raise ValueError("hsplit only works on arrays of 1 or more dimensions")

axis = 1 if tensor.ndim > 1 else 0

result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)

return tuple(maybe_set_base(x, base) for x in result)


def vsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None

if tensor.ndim < 2:
raise ValueError("vsplit only works on arrays of 2 or more dimensions")
result = _impl.split_helper(tensor, indices_or_sections, 0, strict=True)

return tuple(maybe_set_base(x, base) for x in result)


def dsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None

if tensor.ndim < 3:
raise ValueError("dsplit only works on arrays of 3 or more dimensions")
result = _impl.split_helper(tensor, indices_or_sections, 2, strict=True)

return tuple(maybe_set_base(x, base) for x in result)


def kron(a, b):
a_tensor, b_tensor = _helpers.to_tensors(a, b)
result = torch.kron(a_tensor, b_tensor)
return asarray(result)


def tile(A, reps):
a_tensor = asarray(A).get()
if isinstance(reps, int):
reps = (reps,)

result = torch.tile(a_tensor, reps)
return asarray(result)


def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
if axis != 0 or retstep or not endpoint:
raise NotImplementedError
Expand Down
29 changes: 5 additions & 24 deletions torch_np/tests/numpy_tests/lib/test_shape_base_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import sys
import pytest

from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, array_split,
split, hsplit, dsplit, vsplit, kron, tile,
expand_dims, take_along_axis, put_along_axis)
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes,
take_along_axis, put_along_axis)

import torch_np as np
from torch_np import column_stack, dstack, expand_dims
from torch_np import (column_stack, dstack, expand_dims, array_split,
split, hsplit, dsplit, vsplit, kron, tile,)

from torch_np.random import rand
from torch_np.random import rand, randint

from torch_np.testing import assert_array_equal, assert_equal, assert_
from pytest import raises as assert_raises
Expand Down Expand Up @@ -275,7 +275,6 @@ def test_repeated_axis(self):
assert_raises(ValueError, expand_dims, a, axis=(1, 1))


@pytest.mark.xfail(reason="TODO: implement")
class TestArraySplit:
def test_integer_0_split(self):
a = np.arange(10)
Expand Down Expand Up @@ -410,7 +409,6 @@ def test_index_split_high_bound(self):
compare_results(res, desired)


@pytest.mark.xfail(reason="TODO: implement")
class TestSplit:
# The split function is essentially the same as array_split,
# except that it test if splitting will result in an
Expand Down Expand Up @@ -493,7 +491,6 @@ def test_generator(self):

# array_split has more comprehensive test of splitting.
# only do simple test on hsplit, vsplit, and dsplit
@pytest.mark.xfail(reason="TODO: implement")
class TestHsplit:
"""Only testing for integer splits.

Expand Down Expand Up @@ -523,7 +520,6 @@ def test_2D_array(self):
compare_results(res, desired)


@pytest.mark.xfail(reason="TODO: implement")
class TestVsplit:
"""Only testing for integer splits.

Expand Down Expand Up @@ -551,7 +547,6 @@ def test_2D_array(self):
compare_results(res, desired)


@pytest.mark.xfail(reason="TODO: implement")
class TestDsplit:
# Only testing for integer splits.
def test_non_iterable(self):
Expand Down Expand Up @@ -640,7 +635,6 @@ def test_squeeze_axis_handling(self):
np.squeeze(np.array([[1], [2], [3]]), axis=0)


@pytest.mark.xfail(reason="TODO: implement")
class TestKron:
def test_basic(self):
# Using 0-dimensional ndarray
Expand Down Expand Up @@ -671,16 +665,6 @@ def test_basic(self):
k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
assert_array_equal(np.kron(a, b), k)

def test_return_type(self):
class myarray(np.ndarray):
__array_priority__ = 1.0

a = np.ones([2, 2])
ma = myarray(a.shape, a.dtype, a.data)
assert_equal(type(kron(a, a)), np.ndarray)
assert_equal(type(kron(ma, ma)), myarray)
assert_equal(type(kron(a, ma)), myarray)
assert_equal(type(kron(ma, a)), myarray)

@pytest.mark.parametrize(
"shape_a,shape_b", [
Expand All @@ -703,7 +687,6 @@ def test_kron_shape(self, shape_a, shape_b):
k.shape, expected_shape), "Unexpected shape from kron"


@pytest.mark.xfail(reason="TODO: implement")
class TestTile:
def test_basic(self):
a = np.array([0, 1, 2])
Expand Down Expand Up @@ -731,8 +714,6 @@ def test_empty(self):
assert_equal(d, (3, 2, 0))

def test_kroncompare(self):
from numpy.random import randint

reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
for s in shape:
Expand Down