Skip to content

Commit ef54d85

Browse files
committed
test_isdtype
1 parent a533680 commit ef54d85

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

array_api_tests/_array_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __repr__(self):
6262
]
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
65+
_funcs += ["isdtype"] # TODO: bump spec and update array-api-tests to new spec layout
6566
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
6667

6768
for attr in _top_level_attrs:

array_api_tests/dtype_helpers.py

+12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"bool_and_all_int_dtypes",
2424
"dtype_to_name",
2525
"dtype_to_scalars",
26+
"kind_to_dtypes",
2627
"is_int_dtype",
2728
"is_float_dtype",
2829
"get_scalar_type",
@@ -125,6 +126,17 @@ def __repr__(self):
125126
)
126127

127128

129+
kind_to_dtypes = {
130+
"bool": [xp.bool],
131+
"signed integer": int_dtypes,
132+
"unsigned integer": uint_dtypes,
133+
"integral": all_int_dtypes,
134+
"real floating": float_dtypes,
135+
"complex floating": complex_dtypes,
136+
"numeric": numeric_dtypes,
137+
}
138+
139+
128140
def is_int_dtype(dtype):
129141
return dtype in all_int_dtypes
130142

array_api_tests/test_data_type_functions.py

+41
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from hypothesis import strategies as st
77

88
from . import _array_module as xp
9+
from . import api_version
910
from . import dtype_helpers as dh
1011
from . import hypothesis_helpers as hh
1112
from . import pytest_helpers as ph
@@ -174,6 +175,46 @@ def test_iinfo(dtype):
174175
# TODO: test values
175176

176177

178+
if api_version >= "2022.12":
179+
180+
def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
181+
return st.one_of(
182+
xps.scalar_dtypes(),
183+
st.sampled_from(
184+
[
185+
"bool",
186+
"signed integer",
187+
"unsigned integer",
188+
"integral",
189+
"real floating",
190+
"complex floating",
191+
"numeric",
192+
]
193+
),
194+
)
195+
196+
@given(
197+
dtype=xps.scalar_dtypes(),
198+
kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
199+
)
200+
def test_isdtype(dtype, kind):
201+
out = xp.isdtype(dtype, kind)
202+
203+
assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
204+
_kinds = kind if isinstance(kind, tuple) else (kind,)
205+
expected = False
206+
for _kind in _kinds:
207+
if isinstance(_kind, str):
208+
if dtype in dh.kind_to_dtypes[_kind]:
209+
expected = True
210+
break
211+
else:
212+
if dtype == _kind:
213+
expected = True
214+
break
215+
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
216+
217+
177218
@given(hh.mutually_promotable_dtypes(None))
178219
def test_result_type(dtypes):
179220
out = xp.result_type(*dtypes)

0 commit comments

Comments
 (0)