Skip to content

Commit 8f1b501

Browse files
committed
MAINT: rebase, pick up changes from main
1 parent 409802b commit 8f1b501

File tree

2 files changed

+9
-31
lines changed

2 files changed

+9
-31
lines changed

torch_np/_ndarray.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -493,26 +493,6 @@ def wrapped(x, *args, **kwds):
493493
###### dtype routines
494494

495495

496-
def can_cast(from_, to, casting="safe"):
497-
from_ = _extract_dtype(from_)
498-
to_ = extract_dtype(to_)
499-
500-
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
501-
502-
'''
503-
# XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed.
504-
try:
505-
from_dtype = asarray(from_).dtype
506-
except (TypeError, RuntimeError):
507-
# not an array_like; try convering to a dtype
508-
from_dtype = _dtypes.dtype(from_)
509-
510-
try:
511-
to_dtype = asarray(to).dtype
512-
except (TypeError, RuntimeError):
513-
to_dtype = _dtypes.dtype(to)
514-
'''
515-
516496
def _extract_dtype(entry):
517497
try:
518498
dty = _dtypes.dtype(entry)
@@ -521,6 +501,13 @@ def _extract_dtype(entry):
521501
return dty
522502

523503

504+
def can_cast(from_, to, casting="safe"):
505+
from_ = _extract_dtype(from_)
506+
to_ = _extract_dtype(to)
507+
508+
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
509+
510+
524511
def result_type(*arrays_and_dtypes):
525512
dtypes = []
526513

torch_np/_wrapper.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,9 @@ def empty(shape, dtype=float, order="C", *, like=None):
287287
_util.subok_not_ok(like)
288288
if order != "C":
289289
raise NotImplementedError
290-
291290
if dtype is None:
292-
from ._detail._scalar_types import default_float_type
293-
294-
dtype = default_float_type.torch_dtype
295-
291+
dtype = _dtypes_impl.default_float_dtype
296292
result = torch.empty(shape, dtype=dtype)
297-
298293
return asarray(result)
299294

300295

@@ -348,12 +343,8 @@ def ones(shape, dtype=None, order="C", *, like=None):
348343
if order != "C":
349344
raise NotImplementedError
350345
if dtype is None:
351-
from ._detail._scalar_types import default_float_type
352-
353-
dtype = default_float_type.torch_dtype
354-
346+
dtype = _dtypes_impl.default_float_dtype
355347
result = torch.ones(shape, dtype=dtype)
356-
357348
return asarray(result)
358349

359350

0 commit comments

Comments
 (0)