Skip to content

Commit 6799cc6

Browse files
committed
test_isdtype
1 parent a533680 commit 6799cc6

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

array_api_tests/_array_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65+
_funcs += ["isdtype"] # TODO: bump spec and update array-api-tests to new spec layout
6566
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6667

6768
for attr in _top_level_attrs:

array_api_tests/dtype_helpers.py

+12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"bool_and_all_int_dtypes",
2424
"dtype_to_name",
2525
"dtype_to_scalars",
26+
"kind_to_dtypes",
2627
"is_int_dtype",
2728
"is_float_dtype",
2829
"get_scalar_type",
@@ -125,6 +126,17 @@ def __repr__(self):
125126
)
126127

127128

129+
kind_to_dtypes = {
130+
"bool": [xp.bool],
131+
"signed integer": int_dtypes,
132+
"unsigned integer": uint_dtypes,
133+
"integral": all_int_dtypes,
134+
"real floating": float_dtypes,
135+
"complex floating": complex_dtypes,
136+
"numeric": numeric_dtypes,
137+
}
138+
139+
128140
def is_int_dtype(dtype):
129141
return dtype in all_int_dtypes
130142

array_api_tests/test_data_type_functions.py

+27
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,33 @@ def test_iinfo(dtype):
174174
# TODO: test values
175175

176176

177+
def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
178+
return xps.scalar_dtypes() | st.sampled_from(list(dh.kind_to_dtypes.keys()))
179+
180+
181+
@pytest.mark.min_version("2022.12")
182+
@given(
183+
dtype=xps.scalar_dtypes(),
184+
kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
185+
)
186+
def test_isdtype(dtype, kind):
187+
out = xp.isdtype(dtype, kind)
188+
189+
assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
190+
_kinds = kind if isinstance(kind, tuple) else (kind,)
191+
expected = False
192+
for _kind in _kinds:
193+
if isinstance(_kind, str):
194+
if dtype in dh.kind_to_dtypes[_kind]:
195+
expected = True
196+
break
197+
else:
198+
if dtype == _kind:
199+
expected = True
200+
break
201+
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
202+
203+
177204
@given(hh.mutually_promotable_dtypes(None))
178205
def test_result_type(dtypes):
179206
out = xp.result_type(*dtypes)

0 commit comments

Comments
 (0)