Skip to content

Commit 81f2adb

Browse files
committed
BUG: zeros dtype defaults to float64
1 parent d1e7639 commit 81f2adb

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torch_np/_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,13 @@ def ones_like(a, dtype=None, order="K", subok=False, shape=None):
274274

275275
# XXX: dtype=float
276276
@_decorators.dtype_to_torch
277-
def zeros(shape, dtype=float, order="C", *, like=None):
277+
def zeros(shape, dtype=None, order="C", *, like=None):
278278
_util.subok_not_ok(like)
279279
if order != "C":
280280
raise NotImplementedError
281+
if dtype is None:
282+
from ._detail._scalar_types import default_float_type
283+
dtype = default_float_type.torch_dtype
281284
return asarray(torch.zeros(shape, dtype=dtype))
282285

283286

0 commit comments

Comments
 (0)