Skip to content

Commit f2ff966

Browse files
committed
MAINT: split arange/geomspace + full/empty etc
1 parent af97984 commit f2ff966

File tree

2 files changed

+114
-63
lines changed

2 files changed

+114
-63
lines changed

torch_np/_detail/implementations.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def split_helper(tensor, indices_or_sections, axis, strict=False):
138138
if isinstance(indices_or_sections, int):
139139
return split_helper_int(tensor, indices_or_sections, axis, strict)
140140
elif isinstance(indices_or_sections, (list, tuple)):
141-
return split_helper_list(tensor, list(indices_or_sections), axis, strict)
141+
# NB: drop split=..., it only applies to split_helper_int
142+
return split_helper_list(tensor, list(indices_or_sections), axis)
142143
else:
143144
raise TypeError("split_helper: ", type(indices_or_sections))
144145

@@ -170,7 +171,7 @@ def split_helper_int(tensor, indices_or_sections, axis, strict=False):
170171
return result
171172

172173

173-
def split_helper_list(tensor, indices_or_sections, axis, strict=False):
174+
def split_helper_list(tensor, indices_or_sections, axis):
174175
if not isinstance(indices_or_sections, list):
175176
raise NotImplementedError("split: indices_or_sections: list")
176177
# numpy expectes indices, while torch expects lengths of sections
@@ -256,7 +257,9 @@ def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"):
256257

257258
sl = (slice(None),) * axis + (None,)
258259
expanded_tensors = [tensor[sl] for tensor in tensors]
259-
result = concatenate(expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting)
260+
result = concatenate(
261+
expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting
262+
)
260263

261264
return result
262265

@@ -374,10 +377,99 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
374377
return output
375378

376379

377-
378380
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
379381
int_dtype = _dtypes_impl.default_int_dtype
380382
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
381383

382384
result = torch.bincount(x_tensor, weights_tensor, minlength)
383385
return result
386+
387+
388+
# ### linspace, geomspace, logspace and arange ###
389+
390+
391+
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
392+
if axis != 0 or not endpoint:
393+
raise NotImplementedError
394+
tstart, tstop = torch.as_tensor([start, stop])
395+
base = torch.pow(tstop / tstart, 1.0 / (num - 1))
396+
result = torch.logspace(
397+
torch.log(tstart) / torch.log(base),
398+
torch.log(tstop) / torch.log(base),
399+
num,
400+
base=base,
401+
)
402+
return result
403+
404+
405+
def arange(start=None, stop=None, step=1, dtype=None):
406+
if step == 0:
407+
raise ZeroDivisionError
408+
if stop is None and start is None:
409+
raise TypeError
410+
if stop is None:
411+
# XXX: this breaks if start is passed as a kwarg:
412+
# arange(start=4) should raise (no stop) but doesn't
413+
start, stop = 0, start
414+
if start is None:
415+
start = 0
416+
417+
if dtype is None:
418+
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
419+
dtype = _dtypes_impl.default_int_dtype
420+
dt_list.append(dtype)
421+
dtype = _dtypes_impl.result_type_impl(dt_list)
422+
423+
try:
424+
return torch.arange(start, stop, step, dtype=dtype)
425+
except RuntimeError:
426+
raise ValueError("Maximum allowed size exceeded")
427+
428+
429+
# ### empty/full et al ###
430+
431+
432+
def eye(N, M=None, k=0, dtype=float):
433+
if M is None:
434+
M = N
435+
z = torch.zeros(N, M, dtype=dtype)
436+
z.diagonal(k).fill_(1)
437+
return z
438+
439+
440+
def zeros_like(a, dtype=None, shape=None):
441+
result = torch.zeros_like(a, dtype=dtype)
442+
if shape is not None:
443+
result = result.reshape(shape)
444+
return result
445+
446+
447+
def ones_like(a, dtype=None, shape=None):
448+
result = torch.ones_like(a, dtype=dtype)
449+
if shape is not None:
450+
result = result.reshape(shape)
451+
return result
452+
453+
454+
def full_like(a, fill_value, dtype=None, shape=None):
455+
# XXX: fill_value broadcasts
456+
result = torch.full_like(a, fill_value, dtype=dtype)
457+
if shape is not None:
458+
result = result.reshape(shape)
459+
return result
460+
461+
462+
def empty_like(prototype, dtype=None, shape=None):
463+
result = torch.empty_like(prototype, dtype=dtype)
464+
if shape is not None:
465+
result = result.reshape(shape)
466+
return result
467+
468+
469+
def full(shape, fill_value, dtype=None):
470+
if dtype is None:
471+
dtype = fill_value.dtype
472+
if not isinstance(shape, (tuple, list)):
473+
shape = (shape,)
474+
result = torch.full(shape, fill_value, dtype=dtype)
475+
return result

torch_np/_wrapper.py

Lines changed: 18 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def dstack(tup, *, dtype=None, casting="same_kind"):
151151
# but {h,v}stack do. Hence add them here for consistency.
152152
tensors = _helpers.to_tensors(*tup)
153153
result = _impl.dstack(tensors, dtype=dtype, casting=casting)
154-
return asarray(result)
154+
return asarray(result)
155155

156156

157157
@_decorators.dtype_to_torch
@@ -257,49 +257,27 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
257257
return asarray(torch.linspace(start, stop, num, dtype=dtype))
258258

259259

260+
@_decorators.dtype_to_torch
260261
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
261262
if axis != 0 or not endpoint:
262263
raise NotImplementedError
263-
tstart, tstop = torch.as_tensor([start, stop])
264-
base = torch.pow(tstop / tstart, 1.0 / (num - 1))
265-
result = torch.logspace(
266-
torch.log(tstart) / torch.log(base),
267-
torch.log(tstop) / torch.log(base),
268-
num,
269-
base=base,
270-
)
264+
result = _impl.geomspace(start, stop, num, endpoint, dtype, axis)
271265
return asarray(result)
272266

273267

268+
@_decorators.dtype_to_torch
274269
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
275270
if axis != 0 or not endpoint:
276271
raise NotImplementedError
277272
return asarray(torch.logspace(start, stop, num, base=base, dtype=dtype))
278273

279274

275+
@_decorators.dtype_to_torch
280276
def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
281277
_util.subok_not_ok(like)
282-
if step == 0:
283-
raise ZeroDivisionError
284-
if stop is None and start is None:
285-
raise TypeError
286-
if stop is None:
287-
# XXX: this breaks if start is passed as a kwarg:
288-
# arange(start=4) should raise (no stop) but doesn't
289-
start, stop = 0, start
290-
if start is None:
291-
start = 0
292-
293-
if dtype is None:
294-
dtype = _dtypes.default_int_type()
295-
dtype = result_type(start, stop, step, dtype)
296-
torch_dtype = _dtypes.torch_dtype_from(dtype)
297278
start, stop, step = _helpers.ndarrays_to_tensors(start, stop, step)
298-
299-
try:
300-
return asarray(torch.arange(start, stop, step, dtype=torch_dtype))
301-
except RuntimeError:
302-
raise ValueError("Maximum allowed size exceeded")
279+
result = _impl.arange(start, stop, step, dtype=dtype)
280+
return asarray(result)
303281

304282

305283
@_decorators.dtype_to_torch
@@ -316,14 +294,12 @@ def empty(shape, dtype=float, order="C", *, like=None):
316294
# NB: *_like function deliberately deviate from numpy: it has subok=True
317295
# as the default; we set subok=False and raise on anything else.
318296
@asarray_replacer()
297+
@_decorators.dtype_to_torch
319298
def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
320299
_util.subok_not_ok(subok=subok)
321300
if order != "K":
322301
raise NotImplementedError
323-
torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype)
324-
result = torch.empty_like(prototype, dtype=torch_dtype)
325-
if shape is not None:
326-
result = result.reshape(shape)
302+
result = _impl.empty_like(prototype, dtype=dtype, shape=shape)
327303
return result
328304

329305

@@ -332,28 +308,18 @@ def full(shape, fill_value, dtype=None, order="C", *, like=None):
332308
_util.subok_not_ok(like)
333309
if order != "C":
334310
raise NotImplementedError
335-
336311
fill_value = asarray(fill_value).get()
337-
if dtype is None:
338-
dtype = fill_value.dtype
339-
340-
if not isinstance(shape, (tuple, list)):
341-
shape = (shape,)
342-
343-
result = torch.full(shape, fill_value, dtype=dtype)
344-
312+
result = _impl.full(shape, fill_value, dtype=dtype)
345313
return asarray(result)
346314

347315

348316
@asarray_replacer()
317+
@_decorators.dtype_to_torch
349318
def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
350319
_util.subok_not_ok(subok=subok)
351320
if order != "K":
352321
raise NotImplementedError
353-
torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype)
354-
result = torch.full_like(a, fill_value, dtype=torch_dtype)
355-
if shape is not None:
356-
result = result.reshape(shape)
322+
result = _impl.full_like(a, fill_value, dtype=dtype, shape=shape)
357323
return result
358324

359325

@@ -369,14 +335,12 @@ def ones(shape, dtype=None, order="C", *, like=None):
369335

370336

371337
@asarray_replacer()
338+
@_decorators.dtype_to_torch
372339
def ones_like(a, dtype=None, order="K", subok=False, shape=None):
373340
_util.subok_not_ok(subok=subok)
374341
if order != "K":
375342
raise NotImplementedError
376-
torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype)
377-
result = torch.ones_like(a, dtype=torch_dtype)
378-
if shape is not None:
379-
result = result.reshape(shape)
343+
result = _impl.ones_like(a, dtype=dtype, shape=shape)
380344
return result
381345

382346

@@ -392,14 +356,12 @@ def zeros(shape, dtype=None, order="C", *, like=None):
392356

393357

394358
@asarray_replacer()
359+
@_decorators.dtype_to_torch
395360
def zeros_like(a, dtype=None, order="K", subok=False, shape=None):
396361
_util.subok_not_ok(subok=subok)
397362
if order != "K":
398363
raise NotImplementedError
399-
torch_dtype = None if dtype is None else _dtypes.torch_dtype_from(dtype)
400-
result = torch.zeros_like(a, dtype=torch_dtype)
401-
if shape is not None:
402-
result = result.reshape(shape)
364+
result = _impl.zeros_like(a, dtype=dtype, shape=shape)
403365
return result
404366

405367

@@ -408,11 +370,8 @@ def eye(N, M=None, k=0, dtype=float, order="C", *, like=None):
408370
_util.subok_not_ok(like)
409371
if order != "C":
410372
raise NotImplementedError
411-
if M is None:
412-
M = N
413-
z = torch.zeros(N, M, dtype=dtype)
414-
z.diagonal(k).fill_(1)
415-
return asarray(z)
373+
result = _impl.eye(N, M, k, dtype)
374+
return asarray(result)
416375

417376

418377
def identity(n, dtype=None, *, like=None):

0 commit comments

Comments
 (0)