Skip to content

Commit 707f4a4

Browse files
committed
SQUASH: premiliminary split arange/geomspace
1 parent af97984 commit 707f4a4

File tree

2 files changed

+50
-30
lines changed

2 files changed

+50
-30
lines changed

torch_np/_detail/implementations.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def split_helper(tensor, indices_or_sections, axis, strict=False):
138138
if isinstance(indices_or_sections, int):
139139
return split_helper_int(tensor, indices_or_sections, axis, strict)
140140
elif isinstance(indices_or_sections, (list, tuple)):
141-
return split_helper_list(tensor, list(indices_or_sections), axis, strict)
141+
# NB: drop split=..., it only applies to split_helper_int
142+
return split_helper_list(tensor, list(indices_or_sections), axis)
142143
else:
143144
raise TypeError("split_helper: ", type(indices_or_sections))
144145

@@ -170,7 +171,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False):
170171
return result
171172

172173

173-
def split_helper_list(tensor, indices_or_sections, axis, strict=False):
174+
def split_helper_list(tensor, indices_or_sections, axis):
174175
if not isinstance(indices_or_sections, list):
175176
raise NotImplementedError("split: indices_or_sections: list")
176177
# numpy expectes indices, while torch expects lengths of sections
@@ -381,3 +382,44 @@ def bincount(x_tensor, /, weights_tensor=None, minlength=0):
381382

382383
result = torch.bincount(x_tensor, weights_tensor, minlength)
383384
return result
385+
386+
387+
# ### linspace, geomspace, logspace and arange ###
388+
389+
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
390+
if axis != 0 or not endpoint:
391+
raise NotImplementedError
392+
tstart, tstop = torch.as_tensor([start, stop])
393+
base = torch.pow(tstop / tstart, 1.0 / (num - 1))
394+
result = torch.logspace(
395+
torch.log(tstart) / torch.log(base),
396+
torch.log(tstop) / torch.log(base),
397+
num,
398+
base=base,
399+
)
400+
return result
401+
402+
403+
404+
def arange(start=None, stop=None, step=1, dtype=None):
405+
if step == 0:
406+
raise ZeroDivisionError
407+
if stop is None and start is None:
408+
raise TypeError
409+
if stop is None:
410+
# XXX: this breaks if start is passed as a kwarg:
411+
# arange(start=4) should raise (no stop) but doesn't
412+
start, stop = 0, start
413+
if start is None:
414+
start = 0
415+
416+
if dtype is None:
417+
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
418+
dtype = _dtypes_impl.default_int_dtype
419+
dt_list.append(dtype)
420+
dtype = _dtypes_impl.result_type_impl(dt_list)
421+
422+
try:
423+
return torch.arange(start, stop, step, dtype=dtype)
424+
except RuntimeError:
425+
raise ValueError("Maximum allowed size exceeded")

torch_np/_wrapper.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -257,49 +257,27 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
257257
return asarray(torch.linspace(start, stop, num, dtype=dtype))
258258

259259

260+
@_decorators.dtype_to_torch
260261
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
261262
if axis != 0 or not endpoint:
262263
raise NotImplementedError
263-
tstart, tstop = torch.as_tensor([start, stop])
264-
base = torch.pow(tstop / tstart, 1.0 / (num - 1))
265-
result = torch.logspace(
266-
torch.log(tstart) / torch.log(base),
267-
torch.log(tstop) / torch.log(base),
268-
num,
269-
base=base,
270-
)
264+
result = _impl.geomspace(start, stop, num, endpoint, dtype, axis)
271265
return asarray(result)
272266

273267

268+
@_decorators.dtype_to_torch
274269
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
275270
if axis != 0 or not endpoint:
276271
raise NotImplementedError
277272
return asarray(torch.logspace(start, stop, num, base=base, dtype=dtype))
278273

279274

275+
@_decorators.dtype_to_torch
280276
def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
281277
_util.subok_not_ok(like)
282-
if step == 0:
283-
raise ZeroDivisionError
284-
if stop is None and start is None:
285-
raise TypeError
286-
if stop is None:
287-
# XXX: this breaks if start is passed as a kwarg:
288-
# arange(start=4) should raise (no stop) but doesn't
289-
start, stop = 0, start
290-
if start is None:
291-
start = 0
292-
293-
if dtype is None:
294-
dtype = _dtypes.default_int_type()
295-
dtype = result_type(start, stop, step, dtype)
296-
torch_dtype = _dtypes.torch_dtype_from(dtype)
297278
start, stop, step = _helpers.ndarrays_to_tensors(start, stop, step)
298-
299-
try:
300-
return asarray(torch.arange(start, stop, step, dtype=torch_dtype))
301-
except RuntimeError:
302-
raise ValueError("Maximum allowed size exceeded")
279+
result = _impl.arange(start, stop, step, dtype=dtype)
280+
return asarray(result)
303281

304282

305283
@_decorators.dtype_to_torch

0 commit comments

Comments
 (0)