@@ -138,7 +138,8 @@ def split_helper(tensor, indices_or_sections, axis, strict=False):
138
138
if isinstance (indices_or_sections , int ):
139
139
return split_helper_int (tensor , indices_or_sections , axis , strict )
140
140
elif isinstance (indices_or_sections , (list , tuple )):
141
- return split_helper_list (tensor , list (indices_or_sections ), axis , strict )
141
+ # NB: drop split=..., it only applies to split_helper_int
142
+ return split_helper_list (tensor , list (indices_or_sections ), axis )
142
143
else :
143
144
raise TypeError ("split_helper: " , type (indices_or_sections ))
144
145
@@ -170,7 +171,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False):
170
171
return result
171
172
172
173
173
- def split_helper_list (tensor , indices_or_sections , axis , strict = False ):
174
+ def split_helper_list (tensor , indices_or_sections , axis ):
174
175
if not isinstance (indices_or_sections , list ):
175
176
raise NotImplementedError ("split: indices_or_sections: list" )
176
177
# numpy expectes indices, while torch expects lengths of sections
@@ -381,3 +382,44 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
381
382
382
383
result = torch .bincount (x_tensor , weights_tensor , minlength )
383
384
return result
385
+
386
+
387
+ # ### linspace, geomspace, logspace and arange ###
388
+
389
+ def geomspace (start , stop , num = 50 , endpoint = True , dtype = None , axis = 0 ):
390
+ if axis != 0 or not endpoint :
391
+ raise NotImplementedError
392
+ tstart , tstop = torch .as_tensor ([start , stop ])
393
+ base = torch .pow (tstop / tstart , 1.0 / (num - 1 ))
394
+ result = torch .logspace (
395
+ torch .log (tstart ) / torch .log (base ),
396
+ torch .log (tstop ) / torch .log (base ),
397
+ num ,
398
+ base = base ,
399
+ )
400
+ return result
401
+
402
+
403
+
404
+ def arange (start = None , stop = None , step = 1 , dtype = None ):
405
+ if step == 0 :
406
+ raise ZeroDivisionError
407
+ if stop is None and start is None :
408
+ raise TypeError
409
+ if stop is None :
410
+ # XXX: this breaks if start is passed as a kwarg:
411
+ # arange(start=4) should raise (no stop) but doesn't
412
+ start , stop = 0 , start
413
+ if start is None :
414
+ start = 0
415
+
416
+ if dtype is None :
417
+ dt_list = [_util ._coerce_to_tensor (x ).dtype for x in (start , stop , step )]
418
+ dtype = _dtypes_impl .default_int_dtype
419
+ dt_list .append (dtype )
420
+ dtype = _dtypes_impl .result_type_impl (dt_list )
421
+
422
+ try :
423
+ return torch .arange (start , stop , step , dtype = dtype )
424
+ except RuntimeError :
425
+ raise ValueError ("Maximum allowed size exceeded" )
0 commit comments