Skip to content

Commit c063e33

Browse files
authored
Merge pull request #51 from Quansight-Labs/infer-full-dtype-followup
Make ones/zeros/empty/full dtype handling more uniform
2 parents 7e4ed48 + 1b41e14 commit c063e33

File tree

6 files changed

+81
-249
lines changed

6 files changed

+81
-249
lines changed

torch_np/_dtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class bool_(generic):
197197
_typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]}
198198
_torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]}
199199

200+
200201
_aliases = {
201202
"u1": uint8,
202203
"i1": int8,
@@ -285,6 +286,11 @@ def name(self):
285286
def type(self):
286287
return self._scalar_type
287288

289+
@property
290+
def kind(self):
291+
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
292+
return _torch_dtypes[self.torch_dtype].name[0]
293+
288294
@property
289295
def typecode(self):
290296
return self._scalar_type.typecode

torch_np/_ndarray.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,6 @@ def wrapped(x, *args, **kwds):
493493
###### dtype routines
494494

495495

496-
def can_cast(from_, to, casting="safe"):
497-
from_ = from_.dtype if isinstance(from_, ndarray) else _dtypes.dtype(from_)
498-
to_ = to.dtype if isinstance(to, ndarray) else _dtypes.dtype(to)
499-
500-
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
501-
502-
503496
def _extract_dtype(entry):
504497
try:
505498
dty = _dtypes.dtype(entry)
@@ -508,6 +501,13 @@ def _extract_dtype(entry):
508501
return dty
509502

510503

504+
def can_cast(from_, to, casting="safe"):
505+
from_ = _extract_dtype(from_)
506+
to_ = _extract_dtype(to)
507+
508+
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
509+
510+
511511
def result_type(*arrays_and_dtypes):
512512
dtypes = []
513513

torch_np/_wrapper.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,15 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
288288
raise ValueError("Maximum allowed size exceeded")
289289

290290

291+
@_decorators.dtype_to_torch
291292
def empty(shape, dtype=float, order="C", *, like=None):
292293
_util.subok_not_ok(like)
293294
if order != "C":
294295
raise NotImplementedError
295-
torch_dtype = _dtypes.torch_dtype_from(dtype)
296-
return asarray(torch.empty(shape, dtype=torch_dtype))
296+
if dtype is None:
297+
dtype = _dtypes_impl.default_float_dtype
298+
result = torch.empty(shape, dtype=dtype)
299+
return asarray(result)
297300

298301

299302
# NB: *_like function deliberately deviate from numpy: it has subok=True
@@ -310,17 +313,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
310313
return result
311314

312315

316+
@_decorators.dtype_to_torch
313317
def full(shape, fill_value, dtype=None, order="C", *, like=None):
314318
_util.subok_not_ok(like)
315319
if order != "C":
316320
raise NotImplementedError
317-
if isinstance(fill_value, ndarray):
318-
fill_value = fill_value.get()
321+
322+
fill_value = asarray(fill_value).get()
319323
if dtype is None:
320-
torch_dtype = asarray(fill_value).get().dtype
321-
else:
322-
torch_dtype = _dtypes.torch_dtype_from(dtype)
323-
return asarray(torch.full(shape, fill_value, dtype=torch_dtype))
324+
dtype = fill_value.dtype
325+
326+
if not isinstance(shape, (tuple, list)):
327+
shape = (shape,)
328+
329+
result = torch.full(shape, fill_value, dtype=dtype)
330+
331+
return asarray(result)
324332

325333

326334
@asarray_replacer()
@@ -335,12 +343,15 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
335343
return result
336344

337345

346+
@_decorators.dtype_to_torch
338347
def ones(shape, dtype=None, order="C", *, like=None):
339348
_util.subok_not_ok(like)
340349
if order != "C":
341350
raise NotImplementedError
342-
torch_dtype = _dtypes.torch_dtype_from(dtype)
343-
return asarray(torch.ones(shape, dtype=torch_dtype))
351+
if dtype is None:
352+
dtype = _dtypes_impl.default_float_dtype
353+
result = torch.ones(shape, dtype=dtype)
354+
return asarray(result)
344355

345356

346357
@asarray_replacer()
@@ -362,7 +373,8 @@ def zeros(shape, dtype=None, order="C", *, like=None):
362373
raise NotImplementedError
363374
if dtype is None:
364375
dtype = _dtypes_impl.default_float_dtype
365-
return asarray(torch.zeros(shape, dtype=dtype))
376+
result = torch.zeros(shape, dtype=dtype)
377+
return asarray(result)
366378

367379

368380
@asarray_replacer()

0 commit comments

Comments
 (0)