@@ -158,6 +158,88 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
158
158
)
159
159
160
160
161
+ def array_split (ary , indices_or_sections , axis = 0 ):
162
+ tensor = asarray (ary ).get ()
163
+ axis = _util .normalize_axis_index (axis , tensor .ndim )
164
+
165
+ result = _split_helper (tensor , indices_or_sections , axis )
166
+
167
+ return tuple (asarray (_ ) for _ in result )
168
+
169
+
170
+ def split (ary , indices_or_sections , axis = 0 ):
171
+ tensor = asarray (ary ).get ()
172
+ axis = _util .normalize_axis_index (axis , tensor .ndim )
173
+
174
+ result = _split_helper (tensor , indices_or_sections , axis , strict = True )
175
+
176
+ return tuple (asarray (_ ) for _ in result )
177
+
178
+
179
+ def hsplit (ary , indices_or_sections ):
180
+ tensor = asarray (ary ).get ()
181
+
182
+ if tensor .ndim == 0 :
183
+ raise ValueError ('hsplit only works on arrays of 1 or more dimensions' )
184
+
185
+ axis = 1 if tensor .ndim > 1 else 0
186
+
187
+ result = _split_helper (tensor , indices_or_sections , axis , strict = True )
188
+
189
+ return tuple (asarray (_ ) for _ in result )
190
+
191
+
192
+ def vsplit (ary , indices_or_sections ):
193
+ tensor = asarray (ary ).get ()
194
+
195
+ if tensor .ndim < 2 :
196
+ raise ValueError ('vsplit only works on arrays of 2 or more dimensions' )
197
+ result = _split_helper (tensor , indices_or_sections , 0 , strict = True )
198
+
199
+ return tuple (asarray (_ ) for _ in result )
200
+
201
+
202
+ def dsplit (ary , indices_or_sections ):
203
+ tensor = asarray (ary ).get ()
204
+
205
+ if tensor .ndim < 3 :
206
+ raise ValueError ('dsplit only works on arrays of 3 or more dimensions' )
207
+ result = _split_helper (tensor , indices_or_sections , 2 , strict = True )
208
+
209
+ return tuple (asarray (_ ) for _ in result )
210
+
211
+
212
+ def _split_helper (tensor , indices_or_sections , axis , strict = False ):
213
+ if not isinstance (indices_or_sections , int ):
214
+ raise NotImplementedError ('split: indices_or_sections' )
215
+
216
+ # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
217
+ l , n = tensor .shape [axis ], indices_or_sections
218
+
219
+ if n <= 0 :
220
+ raise ValueError ()
221
+
222
+ if l % n == 0 :
223
+ num , sz = n , l // n
224
+ lst = [sz ] * num
225
+ else :
226
+ if strict :
227
+ raise ValueError ("array split does not result in an equal division" )
228
+
229
+ num , sz = l % n , l // n + 1
230
+ lst = [sz ] * num
231
+
232
+ lrest = l - num * sz
233
+
234
+ sz_1 = sz - 1
235
+ num_1 = lrest // sz_1
236
+ lst += [sz_1 ]* num_1
237
+
238
+ result = torch .split (tensor , lst , axis )
239
+
240
+ return result
241
+
242
+
161
243
def linspace (start , stop , num = 50 , endpoint = True , retstep = False , dtype = None , axis = 0 ):
162
244
if axis != 0 or retstep or not endpoint :
163
245
raise NotImplementedError
0 commit comments