Skip to content

Commit a64b627

Browse files
committed
Test for full()
1 parent 8d3d21c commit a64b627

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

torch_np/tests/test_stuff.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from hypothesis.errors import HypothesisWarning
1010
from hypothesis.extra.array_api import make_strategies_namespace
1111

12-
import torch_np
12+
import torch_np as np
1313

1414
__all__ = ["xps"]
1515

1616
with warnings.catch_warnings():
1717
warnings.filterwarnings("ignore", category=HypothesisWarning)
18-
xps = make_strategies_namespace(torch_np, api_version="2021.12")
18+
np.bool = np.bool_
19+
xps = make_strategies_namespace(np, api_version="2021.12")
1920

2021

2122
def integer_array_indices(shape, result_shape) -> st.SearchStrategy[tuple]:
@@ -40,3 +41,29 @@ def test_integer_indexing(x, data):
4041
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
4142
result = x[idx]
4243
assert result.shape == result_shape
44+
45+
46+
@given(shape=xps.array_shapes(), data=st.data())
47+
def test_full(shape, data):
48+
if data.draw(st.booleans(), label="pass kwargs?"):
49+
kw = {}
50+
else:
51+
dtype = data.draw(st.none() | xps.scalar_dtypes(), label="dtype")
52+
kw = {"dtype": dtype}
53+
_dtype = kw.get("dtype", None) or data.draw(
54+
st.sampled_from([np.bool, np.int64, np.float64]), label="_dtype"
55+
)
56+
fill_value = data.draw(xps.from_dtype(_dtype), label="fill_value")
57+
out = np.full(shape, fill_value, **kw)
58+
if kw.get("dtype", None) is None:
59+
if isinstance(fill_value, bool):
60+
assert out.dtype == np.bool
61+
elif isinstance(fill_value, int):
62+
assert out.dtype == np.int64
63+
else:
64+
assert isinstance(fill_value, float) # sanity check
65+
assert out.dtype == np.float64
66+
else:
67+
assert out.dtype == kw["dtype"]
68+
assert out.shape == shape
69+
assert (out == fill_value).all()

0 commit comments

Comments
 (0)