Skip to content

Commit 7992bcc

Browse files
committed
BUG: zeros dtype defaults to float64
1 parent d1e7639 commit 7992bcc

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

torch_np/_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def to_tensors_or_none(*inputs):
9494

9595

9696
def _outer(x, y):
97+
from ._ndarray import asarray
98+
9799
x_tensor, y_tensor = to_tensors(x, y)
98100
result = torch.outer(x_tensor, y_tensor)
99101
return asarray(result)

torch_np/_wrapper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import torch
99

10-
from . import _dtypes, _helpers, _decorators # isort : skip
1110
from ._detail import _flips, _reductions, _util
1211
from ._detail import implementations as _impl
1312
from ._ndarray import (
@@ -20,6 +19,8 @@
2019
result_type,
2120
)
2221

22+
from . import _decorators, _dtypes, _helpers # isort: skip
23+
2324
# Things to decide on (punt for now)
2425
#
2526
# 1. Q: What are the return types of wrapper functions: plain torch.Tensors or
@@ -272,12 +273,15 @@ def ones_like(a, dtype=None, order="K", subok=False, shape=None):
272273
return result
273274

274275

275-
# XXX: dtype=float
276276
@_decorators.dtype_to_torch
277-
def zeros(shape, dtype=float, order="C", *, like=None):
277+
def zeros(shape, dtype=None, order="C", *, like=None):
278278
_util.subok_not_ok(like)
279279
if order != "C":
280280
raise NotImplementedError
281+
if dtype is None:
282+
from ._detail._scalar_types import default_float_type
283+
284+
dtype = default_float_type.torch_dtype
281285
return asarray(torch.zeros(shape, dtype=dtype))
282286

283287

0 commit comments

Comments
 (0)