Skip to content

Commit 0539351

Browse files
committed
WIP: splits
1 parent 12b5922 commit 0539351

File tree

2 files changed

+86
-9
lines changed

2 files changed

+86
-9
lines changed

torch_np/_wrapper.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,88 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
158158
)
159159

160160

161+
def array_split(ary, indices_or_sections, axis=0):
162+
tensor = asarray(ary).get()
163+
axis = _util.normalize_axis_index(axis, tensor.ndim)
164+
165+
result = _split_helper(tensor, indices_or_sections, axis)
166+
167+
return tuple(asarray(_) for _ in result)
168+
169+
170+
def split(ary, indices_or_sections, axis=0):
171+
tensor = asarray(ary).get()
172+
axis = _util.normalize_axis_index(axis, tensor.ndim)
173+
174+
result = _split_helper(tensor, indices_or_sections, axis, strict=True)
175+
176+
return tuple(asarray(_) for _ in result)
177+
178+
179+
def hsplit(ary, indices_or_sections):
180+
tensor = asarray(ary).get()
181+
182+
if tensor.ndim == 0:
183+
raise ValueError('hsplit only works on arrays of 1 or more dimensions')
184+
185+
axis = 1 if tensor.ndim > 1 else 0
186+
187+
result = _split_helper(tensor, indices_or_sections, axis, strict=True)
188+
189+
return tuple(asarray(_) for _ in result)
190+
191+
192+
def vsplit(ary, indices_or_sections):
193+
tensor = asarray(ary).get()
194+
195+
if tensor.ndim < 2:
196+
raise ValueError('vsplit only works on arrays of 2 or more dimensions')
197+
result = _split_helper(tensor, indices_or_sections, 0, strict=True)
198+
199+
return tuple(asarray(_) for _ in result)
200+
201+
202+
def dsplit(ary, indices_or_sections):
203+
tensor = asarray(ary).get()
204+
205+
if tensor.ndim < 3:
206+
raise ValueError('dsplit only works on arrays of 3 or more dimensions')
207+
result = _split_helper(tensor, indices_or_sections, 2, strict=True)
208+
209+
return tuple(asarray(_) for _ in result)
210+
211+
212+
def _split_helper(tensor, indices_or_sections, axis, strict=False):
213+
if not isinstance(indices_or_sections, int):
214+
raise NotImplementedError('split: indices_or_sections')
215+
216+
# numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
217+
l, n = tensor.shape[axis], indices_or_sections
218+
219+
if n <= 0:
220+
raise ValueError()
221+
222+
if l % n == 0:
223+
num, sz = n, l // n
224+
lst = [sz] * num
225+
else:
226+
if strict:
227+
raise ValueError("array split does not result in an equal division")
228+
229+
num, sz = l % n, l // n + 1
230+
lst = [sz] * num
231+
232+
lrest = l - num*sz
233+
234+
sz_1 = sz - 1
235+
num_1 = lrest // sz_1
236+
lst += [sz_1]*num_1
237+
238+
result = torch.split(tensor, lst, axis)
239+
240+
return result
241+
242+
161243
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
162244
if axis != 0 or retstep or not endpoint:
163245
raise NotImplementedError

torch_np/tests/numpy_tests/lib/test_shape_base_.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import sys
33
import pytest
44

5-
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, array_split,
6-
split, hsplit, dsplit, vsplit, kron, tile,
7-
expand_dims, take_along_axis, put_along_axis)
5+
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, kron, tile,
6+
take_along_axis, put_along_axis)
87

98
import torch_np as np
10-
from torch_np import column_stack, dstack, expand_dims
9+
from torch_np import (column_stack, dstack, expand_dims, array_split,
10+
split, hsplit, dsplit, vsplit,)
1111

1212
from torch_np.random import rand
1313

@@ -275,7 +275,6 @@ def test_repeated_axis(self):
275275
assert_raises(ValueError, expand_dims, a, axis=(1, 1))
276276

277277

278-
@pytest.mark.xfail(reason="TODO: implement")
279278
class TestArraySplit:
280279
def test_integer_0_split(self):
281280
a = np.arange(10)
@@ -410,7 +409,6 @@ def test_index_split_high_bound(self):
410409
compare_results(res, desired)
411410

412411

413-
@pytest.mark.xfail(reason="TODO: implement")
414412
class TestSplit:
415413
# The split function is essentially the same as array_split,
416414
# except that it test if splitting will result in an
@@ -493,7 +491,6 @@ def test_generator(self):
493491

494492
# array_split has more comprehensive test of splitting.
495493
# only do simple test on hsplit, vsplit, and dsplit
496-
@pytest.mark.xfail(reason="TODO: implement")
497494
class TestHsplit:
498495
"""Only testing for integer splits.
499496
@@ -523,7 +520,6 @@ def test_2D_array(self):
523520
compare_results(res, desired)
524521

525522

526-
@pytest.mark.xfail(reason="TODO: implement")
527523
class TestVsplit:
528524
"""Only testing for integer splits.
529525
@@ -551,7 +547,6 @@ def test_2D_array(self):
551547
compare_results(res, desired)
552548

553549

554-
@pytest.mark.xfail(reason="TODO: implement")
555550
class TestDsplit:
556551
# Only testing for integer splits.
557552
def test_non_iterable(self):

0 commit comments

Comments
 (0)