Skip to content

Commit bd0ac5e

Browse files
committed
MAINT: empty/full/ones/zeros: dtype handling
1 parent 00f25b1 commit bd0ac5e

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

torch_np/_wrapper.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,20 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
282282
raise ValueError("Maximum allowed size exceeded")
283283

284284

285+
@_decorators.dtype_to_torch
285286
def empty(shape, dtype=float, order="C", *, like=None):
286287
_util.subok_not_ok(like)
287288
if order != "C":
288289
raise NotImplementedError
289-
torch_dtype = _dtypes.torch_dtype_from(dtype)
290-
return asarray(torch.empty(shape, dtype=torch_dtype))
290+
291+
if dtype is None:
292+
from ._detail._scalar_types import default_float_type
293+
294+
dtype = default_float_type.torch_dtype
295+
296+
result = torch.empty(shape, dtype=dtype)
297+
298+
return asarray(result)
291299

292300

293301
# NB: *_like function deliberately deviate from numpy: it has subok=True
@@ -303,15 +311,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
303311
result = result.reshape(shape)
304312
return result
305313

314+
306315
@_decorators.dtype_to_torch
307316
def full(shape, fill_value, dtype=None, order="C", *, like=None):
308317
_util.subok_not_ok(like)
309318
if order != "C":
310319
raise NotImplementedError
320+
311321
fill_value = asarray(fill_value).get()
312322
if dtype is None:
313-
dtype = fill_value.dtype
323+
dtype = fill_value.dtype
324+
325+
if not isinstance(shape, (tuple, list)):
326+
shape = (shape,)
327+
314328
result = torch.full(shape, fill_value, dtype=dtype)
329+
315330
return asarray(result)
316331

317332

@@ -327,12 +342,19 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
327342
return result
328343

329344

345+
@_decorators.dtype_to_torch
330346
def ones(shape, dtype=None, order="C", *, like=None):
331347
_util.subok_not_ok(like)
332348
if order != "C":
333349
raise NotImplementedError
334-
torch_dtype = _dtypes.torch_dtype_from(dtype)
335-
return asarray(torch.ones(shape, dtype=torch_dtype))
350+
if dtype is None:
351+
from ._detail._scalar_types import default_float_type
352+
353+
dtype = default_float_type.torch_dtype
354+
355+
result = torch.ones(shape, dtype=dtype)
356+
357+
return asarray(result)
336358

337359

338360
@asarray_replacer()
@@ -354,7 +376,8 @@ def zeros(shape, dtype=None, order="C", *, like=None):
354376
raise NotImplementedError
355377
if dtype is None:
356378
dtype = _dtypes_impl.default_float_dtype
357-
return asarray(torch.zeros(shape, dtype=dtype))
379+
result = torch.zeros(shape, dtype=dtype)
380+
return asarray(result)
358381

359382

360383
@asarray_replacer()

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2666,12 +2666,19 @@ def test_mode(self):
26662666
np.convolve(d, k, mode=None)
26672667

26682668

2669+
class TestDtypePositional:
2670+
2671+
@pytest.mark.xfail(reason='TODO: restore dtypes as positional args')
2672+
def test_dtype_positional(self):
2673+
np.empty((2,), bool)
2674+
2675+
26692676
class TestArgwhere:
26702677

26712678
@pytest.mark.parametrize('nd', [0, 1, 2])
26722679
def test_nd(self, nd):
26732680
# get an nd array with multiple elements in every dimension
2674-
x = np.empty((2,)*nd, bool)
2681+
x = np.empty((2,)*nd, dtype=bool)
26752682

26762683
# none
26772684
x[...] = False

torch_np/tests/numpy_tests/lib/test_shape_base_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ def test_kroncompare(self):
719719
for s in shape:
720720
b = randint(0, 10, size=s)
721721
for r in reps:
722-
a = np.ones(r, b.dtype)
722+
a = np.ones(r, dtype=b.dtype) # TODO: restore dtype positional arg
723723
large = tile(b, r)
724724
klarge = kron(a, b)
725725
assert_equal(large, klarge)

0 commit comments

Comments
 (0)