Skip to content

Commit 2093a12

Browse files
authored
Merge pull request #174 from honno/test-isdtype
`test_isdtype`
2 parents a398866 + 5f5fb08 commit 2093a12

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

array_api_tests/_array_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65-
_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout
65+
_funcs += ["take", "isdtype"] # TODO: bump spec and update array-api-tests to new spec layout
6666
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6767

6868
for attr in _top_level_attrs:

array_api_tests/dtype_helpers.py

+12
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"bool_and_all_int_dtypes",
2525
"dtype_to_name",
2626
"dtype_to_scalars",
27+
"kind_to_dtypes",
2728
"is_int_dtype",
2829
"is_float_dtype",
2930
"get_scalar_type",
@@ -139,6 +140,17 @@ def _filter_stubs(*args):
139140
)
140141

141142

143+
kind_to_dtypes = {
144+
"bool": [xp.bool],
145+
"signed integer": int_dtypes,
146+
"unsigned integer": uint_dtypes,
147+
"integral": all_int_dtypes,
148+
"real floating": float_dtypes,
149+
"complex floating": complex_dtypes,
150+
"numeric": numeric_dtypes,
151+
}
152+
153+
142154
def is_int_dtype(dtype):
143155
return dtype in all_int_dtypes
144156

array_api_tests/test_data_type_functions.py

+27
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,33 @@ def test_iinfo(dtype):
177177
# TODO: test values
178178

179179

180+
def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
181+
return xps.scalar_dtypes() | st.sampled_from(list(dh.kind_to_dtypes.keys()))
182+
183+
184+
@pytest.mark.min_version("2022.12")
185+
@given(
186+
dtype=xps.scalar_dtypes(),
187+
kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
188+
)
189+
def test_isdtype(dtype, kind):
190+
out = xp.isdtype(dtype, kind)
191+
192+
assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
193+
_kinds = kind if isinstance(kind, tuple) else (kind,)
194+
expected = False
195+
for _kind in _kinds:
196+
if isinstance(_kind, str):
197+
if dtype in dh.kind_to_dtypes[_kind]:
198+
expected = True
199+
break
200+
else:
201+
if dtype == _kind:
202+
expected = True
203+
break
204+
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
205+
206+
180207
@given(hh.mutually_promotable_dtypes(None))
181208
def test_result_type(dtypes):
182209
out = xp.result_type(*dtypes)

0 commit comments

Comments
 (0)