Skip to content

Commit 71554d7

Browse files
committed
Rudimentary dtype inferrence in full()
1 parent a64b627 commit 71554d7

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

torch_np/_wrapper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,19 @@ def full(shape, fill_value, dtype=None, order="C", *, like=None):
239239
raise NotImplementedError
240240
if isinstance(fill_value, ndarray):
241241
fill_value = fill_value.get()
242-
torch_dtype = _dtypes.torch_dtype_from(dtype)
242+
if dtype is None:
243+
if isinstance(fill_value, bool):
244+
torch_dtype = torch.bool
245+
elif isinstance(fill_value, int):
246+
torch_dtype = torch.int64
247+
elif isinstance(fill_value, float):
248+
torch_dtype = torch.float64
249+
elif isinstance(fill_value, complex):
250+
torch_dtype = torch.complex128
251+
else:
252+
torch_dtype = _dtypes.torch_dtype_from(dtype)
253+
else:
254+
torch_dtype = _dtypes.torch_dtype_from(dtype)
243255
return asarray(torch.full(shape, fill_value, dtype=torch_dtype))
244256

245257

torch_np/tests/test_stuff.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,7 @@ def test_full(shape, data):
6666
else:
6767
assert out.dtype == kw["dtype"]
6868
assert out.shape == shape
69-
assert (out == fill_value).all()
69+
if fill_value is float("nan"):
70+
assert np.isnan(out).all()
71+
else:
72+
assert (out == fill_value).all()

0 commit comments

Comments
 (0)