Skip to content

Commit cdc6050

Browse files
authored
Merge pull request #46 from Quansight-Labs/splits
WIP: implement {array_, VHD}split
2 parents 12b5922 + a467896 commit cdc6050

File tree

5 files changed

+134
-52
lines changed

5 files changed

+134
-52
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,6 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
110110
raise NotImplementedError
111111

112112

113-
def array_split(ary, indices_or_sections, axis=0):
114-
raise NotImplementedError
115-
116-
117113
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
118114
raise NotImplementedError
119115

@@ -260,10 +256,6 @@ def dot(a, b, out=None):
260256
raise NotImplementedError
261257

262258

263-
def dsplit(ary, indices_or_sections):
264-
raise NotImplementedError
265-
266-
267259
def ediff1d(ary, to_end=None, to_begin=None):
268260
raise NotImplementedError
269261

@@ -417,10 +409,6 @@ def histogramdd(sample, bins=10, range=None, normed=None, weights=None, density=
417409
raise NotImplementedError
418410

419411

420-
def hsplit(ary, indices_or_sections):
421-
raise NotImplementedError
422-
423-
424412
def in1d(ar1, ar2, assume_unique=False, invert=False):
425413
raise NotImplementedError
426414

@@ -493,10 +481,6 @@ def kaiser(M, beta):
493481
raise NotImplementedError
494482

495483

496-
def kron(a, b):
497-
raise NotImplementedError
498-
499-
500484
def lexsort(keys, axis=-1):
501485
raise NotImplementedError
502486

@@ -875,10 +859,6 @@ def sort_complex(a):
875859
raise NotImplementedError
876860

877861

878-
def split(ary, indices_or_sections, axis=0):
879-
raise NotImplementedError
880-
881-
882862
def swapaxes(a, axis1, axis2):
883863
raise NotImplementedError
884864

@@ -895,10 +875,6 @@ def tensordot(a, b, axes=2):
895875
raise NotImplementedError
896876

897877

898-
def tile(A, reps):
899-
raise NotImplementedError
900-
901-
902878
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
903879
raise NotImplementedError
904880

@@ -947,10 +923,6 @@ def vdot(a, b, /):
947923
raise NotImplementedError
948924

949925

950-
def vsplit(ary, indices_or_sections):
951-
raise NotImplementedError
952-
953-
954926
def where(condition, x, y, /):
955927
raise NotImplementedError
956928

torch_np/_detail/implementations.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,56 @@ def tensor_equal(a1_t, a2_t, equal_nan=False):
1515
else:
1616
result = a1_t == a2_t
1717
return bool(result.all())
18+
19+
20+
def split_helper(tensor, indices_or_sections, axis, strict=False):
    """Split ``tensor`` along ``axis``, dispatching on ``indices_or_sections``.

    Parameters
    ----------
    tensor
        Tensor to split; forwarded unchanged to the selected helper.
    indices_or_sections : int, list or tuple
        An ``int`` is handled by ``split_helper_int`` (number of sections);
        a ``list`` or ``tuple`` is handled by ``split_helper_list`` (split
        indices), with tuples converted to a list first.
    axis
        Axis to split along; forwarded unchanged.
    strict : bool, optional
        Forwarded unchanged to the selected helper.

    Returns
    -------
    Whatever the selected helper returns.

    Raises
    ------
    TypeError
        If ``indices_or_sections`` is neither an int nor a list/tuple.
    """
    if isinstance(indices_or_sections, int):
        return split_helper_int(tensor, indices_or_sections, axis, strict)
    if isinstance(indices_or_sections, (list, tuple)):
        return split_helper_list(tensor, list(indices_or_sections), axis, strict)

    # Build a single formatted message: passing the type as a second
    # positional argument makes the exception render as a tuple.
    raise TypeError(
        "split_helper: unsupported indices_or_sections type "
        f"{type(indices_or_sections).__name__!r}; expected int, list or tuple"
    )
27+
28+
29+
def split_helper_int(tensor, indices_or_sections, axis, strict=False):
    """Split ``tensor`` along ``axis`` into ``indices_or_sections`` chunks.

    Follows the numpy rule: with ``length = tensor.shape[axis]`` and
    ``n = indices_or_sections``, the first ``length % n`` chunks have size
    ``length // n + 1`` and the remaining chunks have size ``length // n``.

    Parameters
    ----------
    tensor
        Tensor to split.
    indices_or_sections : int
        Number of chunks to produce; must be positive.
    axis : int
        Axis to split along.
    strict : bool, optional
        If true, refuse to split when ``length`` is not divisible by ``n``.

    Returns
    -------
    The result of ``torch.split(tensor, sizes, axis)`` for the computed sizes.

    Raises
    ------
    NotImplementedError
        If ``indices_or_sections`` is not an int.
    ValueError
        If ``indices_or_sections <= 0``, or if ``strict`` is true and the
        split would be uneven.
    """
    if not isinstance(indices_or_sections, int):
        raise NotImplementedError("split: indices_or_sections")

    length, n_sections = tensor.shape[axis], indices_or_sections

    if n_sections <= 0:
        # Same wording as numpy, instead of an empty ValueError.
        raise ValueError("number sections must be larger than 0.")

    size, remainder = divmod(length, n_sections)
    if remainder and strict:
        raise ValueError("array split does not result in an equal division")

    # `remainder` larger chunks first, then the regular-sized ones.
    sizes = [size + 1] * remainder + [size] * (n_sections - remainder)
    return torch.split(tensor, sizes, axis)
54+
55+
56+
def split_helper_list(tensor, indices_or_sections, axis, strict=False):
    """Split ``tensor`` along ``axis`` at the given indices, numpy-style.

    numpy defines the pieces as the consecutive slices ``[st:end]`` over the
    boundaries ``[0, *indices, length]``. Slicing directly (instead of
    converting indices to section lengths for ``torch.split``) reproduces
    that for every input: indices past the end of the axis yield empty
    pieces, and negative or unsorted indices follow ordinary slice
    semantics rather than producing negative section lengths.

    Parameters
    ----------
    tensor
        Tensor to split.
    indices_or_sections : list
        Positions along ``axis`` at which to split.
    axis : int
        Axis to split along.
    strict : bool, optional
        Accepted for signature parity with ``split_helper_int``; not used.

    Returns
    -------
    tuple
        ``len(indices_or_sections) + 1`` slices of ``tensor``.

    Raises
    ------
    NotImplementedError
        If ``indices_or_sections`` is not a list.
    """
    if not isinstance(indices_or_sections, list):
        raise NotImplementedError("split: indices_or_sections: list")

    length = tensor.shape[axis]
    if axis < 0:
        # The index tuple built below needs a non-negative axis position.
        axis += tensor.ndim

    bounds = [0, *indices_or_sections, length]
    leading = (slice(None),) * axis
    return tuple(
        tensor[leading + (slice(start, stop),)]
        for start, stop in zip(bounds[:-1], bounds[1:])
    )

torch_np/_ndarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ def asarray(a, dtype=None, order=None, *, like=None):
409409
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
410410

411411

412+
def maybe_set_base(tensor, base):
    """Wrap ``tensor`` in an ``ndarray``, forwarding ``base`` unchanged.

    Thin module-level alias for ``ndarray._from_tensor_and_base``.
    NOTE(review): how a ``None`` base is treated is decided by that
    constructor, which is not visible here.
    """
    wrapped = ndarray._from_tensor_and_base(tensor, base)
    return wrapped
414+
415+
412416
class asarray_replacer:
413417
def __init__(self, dispatch="one"):
414418
if dispatch not in ["one", "two"]:

torch_np/_wrapper.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
asarray,
1515
asarray_replacer,
1616
can_cast,
17+
maybe_set_base,
1718
ndarray,
1819
newaxis,
1920
result_type,
@@ -158,6 +159,77 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
158159
)
159160

160161

162+
def array_split(ary, indices_or_sections, axis=0):
    """Split ``ary`` along ``axis`` into a tuple of sub-arrays.

    The input is converted with ``asarray`` and its tensor is passed to
    ``_impl.split_helper`` in its default (non-strict) mode. Each piece is
    wrapped with ``maybe_set_base``; ``ary`` itself is used as the base
    when it is already an ``ndarray``, otherwise the base is ``None``.
    """
    tensor = asarray(ary).get()
    if isinstance(ary, ndarray):
        base = ary
    else:
        base = None
    axis = _util.normalize_axis_index(axis, tensor.ndim)

    pieces = _impl.split_helper(tensor, indices_or_sections, axis)
    return tuple(maybe_set_base(piece, base) for piece in pieces)
170+
171+
172+
def split(ary, indices_or_sections, axis=0):
    """Split ``ary`` along ``axis`` into a tuple of sub-arrays (strict mode).

    Same flow as a non-strict split, except that ``_impl.split_helper`` is
    called with ``strict=True``. Each piece is wrapped with
    ``maybe_set_base``; ``ary`` is used as the base when it is already an
    ``ndarray``, otherwise the base is ``None``.
    """
    tensor = asarray(ary).get()
    if isinstance(ary, ndarray):
        base = ary
    else:
        base = None
    axis = _util.normalize_axis_index(axis, tensor.ndim)

    pieces = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
    return tuple(maybe_set_base(piece, base) for piece in pieces)
180+
181+
182+
def hsplit(ary, indices_or_sections):
    """Split ``ary`` in strict mode along axis 1, or axis 0 for 1-D input.

    Raises
    ------
    ValueError
        If the input is 0-dimensional.
    """
    tensor = asarray(ary).get()
    if isinstance(ary, ndarray):
        base = ary
    else:
        base = None

    if tensor.ndim == 0:
        raise ValueError("hsplit only works on arrays of 1 or more dimensions")

    # 1-D input has no second axis, so it is split along its only axis.
    if tensor.ndim > 1:
        axis = 1
    else:
        axis = 0

    pieces = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
    return tuple(maybe_set_base(piece, base) for piece in pieces)
194+
195+
196+
def vsplit(ary, indices_or_sections):
    """Split ``ary`` in strict mode along axis 0.

    Raises
    ------
    ValueError
        If the input has fewer than 2 dimensions.
    """
    tensor = asarray(ary).get()
    if isinstance(ary, ndarray):
        base = ary
    else:
        base = None

    if tensor.ndim < 2:
        raise ValueError("vsplit only works on arrays of 2 or more dimensions")

    pieces = _impl.split_helper(tensor, indices_or_sections, 0, strict=True)
    return tuple(maybe_set_base(piece, base) for piece in pieces)
205+
206+
207+
def dsplit(ary, indices_or_sections):
    """Split ``ary`` in strict mode along axis 2.

    Raises
    ------
    ValueError
        If the input has fewer than 3 dimensions.
    """
    tensor = asarray(ary).get()
    if isinstance(ary, ndarray):
        base = ary
    else:
        base = None

    if tensor.ndim < 3:
        raise ValueError("dsplit only works on arrays of 3 or more dimensions")

    pieces = _impl.split_helper(tensor, indices_or_sections, 2, strict=True)
    return tuple(maybe_set_base(piece, base) for piece in pieces)
216+
217+
218+
def kron(a, b):
    """Return the Kronecker product of ``a`` and ``b``.

    Both inputs are converted with ``_helpers.to_tensors``, multiplied with
    ``torch.kron``, and the result is wrapped via ``asarray``.
    """
    lhs, rhs = _helpers.to_tensors(a, b)
    return asarray(torch.kron(lhs, rhs))
222+
223+
224+
def tile(A, reps):
    """Repeat ``A`` according to ``reps`` using ``torch.tile``.

    A bare ``int`` for ``reps`` is promoted to a one-element tuple; any
    other value is passed to ``torch.tile`` as given. The result is wrapped
    via ``asarray``.
    """
    a_tensor = asarray(A).get()
    repeats = (reps,) if isinstance(reps, int) else reps
    return asarray(torch.tile(a_tensor, repeats))
231+
232+
161233
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
162234
if axis != 0 or retstep or not endpoint:
163235
raise NotImplementedError

torch_np/tests/numpy_tests/lib/test_shape_base_.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import sys
33
import pytest
44

5-
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, array_split,
6-
split, hsplit, dsplit, vsplit, kron, tile,
7-
expand_dims, take_along_axis, put_along_axis)
5+
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes,
6+
take_along_axis, put_along_axis)
87

98
import torch_np as np
10-
from torch_np import column_stack, dstack, expand_dims
9+
from torch_np import (column_stack, dstack, expand_dims, array_split,
10+
split, hsplit, dsplit, vsplit, kron, tile,)
1111

12-
from torch_np.random import rand
12+
from torch_np.random import rand, randint
1313

1414
from torch_np.testing import assert_array_equal, assert_equal, assert_
1515
from pytest import raises as assert_raises
@@ -275,7 +275,6 @@ def test_repeated_axis(self):
275275
assert_raises(ValueError, expand_dims, a, axis=(1, 1))
276276

277277

278-
@pytest.mark.xfail(reason="TODO: implement")
279278
class TestArraySplit:
280279
def test_integer_0_split(self):
281280
a = np.arange(10)
@@ -410,7 +409,6 @@ def test_index_split_high_bound(self):
410409
compare_results(res, desired)
411410

412411

413-
@pytest.mark.xfail(reason="TODO: implement")
414412
class TestSplit:
415413
# The split function is essentially the same as array_split,
416414
# except that it test if splitting will result in an
@@ -493,7 +491,6 @@ def test_generator(self):
493491

494492
# array_split has more comprehensive test of splitting.
495493
# only do simple test on hsplit, vsplit, and dsplit
496-
@pytest.mark.xfail(reason="TODO: implement")
497494
class TestHsplit:
498495
"""Only testing for integer splits.
499496
@@ -523,7 +520,6 @@ def test_2D_array(self):
523520
compare_results(res, desired)
524521

525522

526-
@pytest.mark.xfail(reason="TODO: implement")
527523
class TestVsplit:
528524
"""Only testing for integer splits.
529525
@@ -551,7 +547,6 @@ def test_2D_array(self):
551547
compare_results(res, desired)
552548

553549

554-
@pytest.mark.xfail(reason="TODO: implement")
555550
class TestDsplit:
556551
# Only testing for integer splits.
557552
def test_non_iterable(self):
@@ -640,7 +635,6 @@ def test_squeeze_axis_handling(self):
640635
np.squeeze(np.array([[1], [2], [3]]), axis=0)
641636

642637

643-
@pytest.mark.xfail(reason="TODO: implement")
644638
class TestKron:
645639
def test_basic(self):
646640
# Using 0-dimensional ndarray
@@ -671,16 +665,6 @@ def test_basic(self):
671665
k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
672666
assert_array_equal(np.kron(a, b), k)
673667

674-
def test_return_type(self):
675-
class myarray(np.ndarray):
676-
__array_priority__ = 1.0
677-
678-
a = np.ones([2, 2])
679-
ma = myarray(a.shape, a.dtype, a.data)
680-
assert_equal(type(kron(a, a)), np.ndarray)
681-
assert_equal(type(kron(ma, ma)), myarray)
682-
assert_equal(type(kron(a, ma)), myarray)
683-
assert_equal(type(kron(ma, a)), myarray)
684668

685669
@pytest.mark.parametrize(
686670
"shape_a,shape_b", [
@@ -703,7 +687,6 @@ def test_kron_shape(self, shape_a, shape_b):
703687
k.shape, expected_shape), "Unexpected shape from kron"
704688

705689

706-
@pytest.mark.xfail(reason="TODO: implement")
707690
class TestTile:
708691
def test_basic(self):
709692
a = np.array([0, 1, 2])
@@ -731,8 +714,6 @@ def test_empty(self):
731714
assert_equal(d, (3, 2, 0))
732715

733716
def test_kroncompare(self):
734-
from numpy.random import randint
735-
736717
reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
737718
shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
738719
for s in shape:

0 commit comments

Comments
 (0)