Skip to content

Commit 89a8d54

Browse files
committed
MAINT: rebase, pick up changes from main
1 parent efc8325 commit 89a8d54

File tree

2 files changed

+9
-31
lines changed

2 files changed

+9
-31
lines changed

torch_np/_ndarray.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -493,26 +493,6 @@ def wrapped(x, *args, **kwds):
493493
###### dtype routines
494494

495495

496-
def can_cast(from_, to, casting="safe"):
497-
from_ = _extract_dtype(from_)
498-
to_ = extract_dtype(to_)
499-
500-
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
501-
502-
'''
503-
# XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed.
504-
try:
505-
from_dtype = asarray(from_).dtype
506-
except (TypeError, RuntimeError):
507-
# not an array_like; try convering to a dtype
508-
from_dtype = _dtypes.dtype(from_)
509-
510-
try:
511-
to_dtype = asarray(to).dtype
512-
except (TypeError, RuntimeError):
513-
to_dtype = _dtypes.dtype(to)
514-
'''
515-
516496
def _extract_dtype(entry):
517497
try:
518498
dty = _dtypes.dtype(entry)
@@ -521,6 +501,13 @@ def _extract_dtype(entry):
521501
return dty
522502

523503

504+
def can_cast(from_, to, casting="safe"):
505+
from_ = _extract_dtype(from_)
506+
to_ = _extract_dtype(to)
507+
508+
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
509+
510+
524511
def result_type(*arrays_and_dtypes):
525512
dtypes = []
526513

torch_np/_wrapper.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,9 @@ def empty(shape, dtype=float, order="C", *, like=None):
293293
_util.subok_not_ok(like)
294294
if order != "C":
295295
raise NotImplementedError
296-
297296
if dtype is None:
298-
from ._detail._scalar_types import default_float_type
299-
300-
dtype = default_float_type.torch_dtype
301-
297+
dtype = _dtypes_impl.default_float_dtype
302298
result = torch.empty(shape, dtype=dtype)
303-
304299
return asarray(result)
305300

306301

@@ -354,12 +349,8 @@ def ones(shape, dtype=None, order="C", *, like=None):
354349
if order != "C":
355350
raise NotImplementedError
356351
if dtype is None:
357-
from ._detail._scalar_types import default_float_type
358-
359-
dtype = default_float_type.torch_dtype
360-
352+
dtype = _dtypes_impl.default_float_dtype
361353
result = torch.ones(shape, dtype=dtype)
362-
363354
return asarray(result)
364355

365356

0 commit comments

Comments
 (0)