Skip to content

Commit a87462d

Browse files
committed
More thorough test_full()
Also implements `DType.kind` to raise `NotImplementedError`
1 parent 47e8adf commit a87462d

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

torch_np/_dtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def name(self):
5555
def type(self):
5656
return self._scalar_type
5757

58+
@property
59+
def kind(self):
60+
raise NotImplementedError
61+
5862
@property
5963
def typecode(self):
6064
return self._scalar_type.typecode

torch_np/tests/test_xps.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,40 @@
2525
xps = make_strategies_namespace(np, api_version="2022.12")
2626

2727

28+
default_dtypes = [np.bool, np.int64, np.float64, np.complex128]
29+
kind_to_strat = {
30+
"b": xps.boolean_dtypes(),
31+
"i": xps.integer_dtypes(),
32+
"u": xps.unsigned_integer_dtypes(sizes=8),
33+
"f": xps.floating_dtypes(),
34+
"c": xps.complex_dtypes(),
35+
}
36+
scalar_dtype_strat = st.one_of(kind_to_strat.values()).map(np.dtype)
37+
38+
2839
@given(shape=xps.array_shapes(), data=st.data())
2940
def test_full(shape, data):
3041
if data.draw(st.booleans(), label="pass kwargs?"):
31-
kw = {}
32-
else:
33-
dtype = data.draw(st.none() | xps.scalar_dtypes(), label="dtype")
42+
dtype = data.draw(st.none() | scalar_dtype_strat, label="dtype")
3443
kw = {"dtype": dtype}
35-
_dtype = kw.get("dtype", None) or data.draw(
36-
st.sampled_from([np.bool, np.int64, np.float64, np.complex128]), label="_dtype"
37-
)
44+
else:
45+
kw = {}
46+
_dtype = kw.get("dtype", None) or data.draw(scalar_dtype_strat, label="_dtype")
3847
values_strat = xps.from_dtype(_dtype)
39-
fill_value = data.draw(
40-
values_strat | values_strat.map(lambda v: np.asarray(v, dtype=_dtype)),
41-
label="fill_value",
42-
)
43-
out = np.full(shape, fill_value, **kw)
44-
if kw.get("dtype", None) is None and not isinstance(fill_value, np.ndarray):
45-
if isinstance(fill_value, bool):
46-
assert out.dtype == np.bool
47-
elif isinstance(fill_value, int):
48-
assert out.dtype == np.int64
49-
elif isinstance(fill_value, float):
50-
assert out.dtype == np.float64
48+
if _dtype not in default_dtypes or data.draw(
49+
st.booleans(), label="fill_value is array?"
50+
):
51+
if specified_dtype := kw.get("dtype", None):
52+
kind = specified_dtype.name[0]
53+
values_dtypes_strat = kind_to_strat[kind]
5154
else:
52-
assert isinstance(fill_value, complex) # sanity check
53-
assert out.dtype == np.complex128
54-
else:
55-
assert out.dtype == _dtype
55+
values_dtypes_strat = st.just(_dtype)
56+
values_strat = values_dtypes_strat.flatmap(
57+
lambda d: values_strat.map(lambda v: np.asarray(v, dtype=d))
58+
)
59+
fill_value = data.draw(values_strat, label="fill_value")
60+
out = np.full(shape, fill_value, **kw)
61+
assert out.dtype == _dtype
5662
assert out.shape == shape
5763
if cmath.isnan(fill_value):
5864
assert np.isnan(out).all()

0 commit comments

Comments
 (0)