Skip to content

Commit 6de03dd

Browse files
committed
MAINT: move *split logic to _impl
1 parent 7724e7b commit 6de03dd

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

torch_np/_detail/implementations.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False):
155155
if not isinstance(indices_or_sections, int):
156156
raise NotImplementedError("split: indices_or_sections")
157157

158+
axis = _util.normalize_axis_index(axis, tensor.ndim)
159+
158160
# numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
159161
l, n = tensor.shape[axis], indices_or_sections
160162

@@ -195,6 +197,26 @@ def split_helper_list(tensor, indices_or_sections, axis):
195197
return torch.split(tensor, lst, axis)
196198

197199

200+
def hsplit(tensor, indices_or_sections):
201+
if tensor.ndim == 0:
202+
raise ValueError("hsplit only works on arrays of 1 or more dimensions")
203+
axis = 1 if tensor.ndim > 1 else 0
204+
return split_helper(tensor, indices_or_sections, axis, strict=True)
205+
206+
207+
def vsplit(tensor, indices_or_sections):
208+
if tensor.ndim < 2:
209+
raise ValueError("vsplit only works on arrays of 2 or more dimensions")
210+
return split_helper(tensor, indices_or_sections, 0, strict=True)
211+
212+
213+
def dsplit(tensor, indices_or_sections):
214+
if tensor.ndim < 3:
215+
raise ValueError("dsplit only works on arrays of 3 or more dimensions")
216+
return split_helper(tensor, indices_or_sections, 2, strict=True)
217+
218+
219+
198220
def clip(tensor, t_min, t_max):
199221
if t_min is not None:
200222
t_min = torch.broadcast_to(t_min, tensor.shape)

torch_np/_wrapper.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -176,56 +176,35 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
176176
def array_split(ary, indices_or_sections, axis=0):
177177
tensor = asarray(ary).get()
178178
base = ary if isinstance(ary, ndarray) else None
179-
axis = _util.normalize_axis_index(axis, tensor.ndim)
180-
181179
result = _impl.split_helper(tensor, indices_or_sections, axis)
182-
183180
return tuple(maybe_set_base(x, base) for x in result)
184181

185182

186183
def split(ary, indices_or_sections, axis=0):
187184
tensor = asarray(ary).get()
188185
base = ary if isinstance(ary, ndarray) else None
189-
axis = _util.normalize_axis_index(axis, tensor.ndim)
190-
191186
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
192-
193187
return tuple(maybe_set_base(x, base) for x in result)
194188

195189

196190
def hsplit(ary, indices_or_sections):
197191
tensor = asarray(ary).get()
198192
base = ary if isinstance(ary, ndarray) else None
199-
200-
if tensor.ndim == 0:
201-
raise ValueError("hsplit only works on arrays of 1 or more dimensions")
202-
203-
axis = 1 if tensor.ndim > 1 else 0
204-
205-
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
206-
193+
result = _impl.hsplit(tensor, indices_or_sections)
207194
return tuple(maybe_set_base(x, base) for x in result)
208195

209196

210197
def vsplit(ary, indices_or_sections):
211198
tensor = asarray(ary).get()
212199
base = ary if isinstance(ary, ndarray) else None
213-
214-
if tensor.ndim < 2:
215-
raise ValueError("vsplit only works on arrays of 2 or more dimensions")
216-
result = _impl.split_helper(tensor, indices_or_sections, 0, strict=True)
217-
200+
result = _impl.vsplit(tensor, indices_or_sections)
218201
return tuple(maybe_set_base(x, base) for x in result)
219202

220203

221204
def dsplit(ary, indices_or_sections):
222205
tensor = asarray(ary).get()
223206
base = ary if isinstance(ary, ndarray) else None
224-
225-
if tensor.ndim < 3:
226-
raise ValueError("dsplit only works on arrays of 3 or more dimensions")
227-
result = _impl.split_helper(tensor, indices_or_sections, 2, strict=True)
228-
207+
result = _impl.dsplit(tensor, indices_or_sections)
229208
return tuple(maybe_set_base(x, base) for x in result)
230209

231210

0 commit comments

Comments
 (0)