From 5f5fb0817a8ec31b4b845413ef29e45db84d9059 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 15 Mar 2023 11:36:23 +0000 Subject: [PATCH] `test_isdtype` --- array_api_tests/_array_module.py | 2 +- array_api_tests/dtype_helpers.py | 12 +++++++++ array_api_tests/test_data_type_functions.py | 27 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 4ab584b7..aeaa4610 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -62,7 +62,7 @@ def __repr__(self): ] _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] -_funcs += ["take"] # TODO: bump spec and update array-api-tests to new spec layout +_funcs += ["take", "isdtype"] # TODO: bump spec and update array-api-tests to new spec layout _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS for attr in _top_level_attrs: diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 8fb8fa7e..c6e90f75 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -24,6 +24,7 @@ "bool_and_all_int_dtypes", "dtype_to_name", "dtype_to_scalars", + "kind_to_dtypes", "is_int_dtype", "is_float_dtype", "get_scalar_type", @@ -139,6 +140,17 @@ def _filter_stubs(*args): ) +kind_to_dtypes = { + "bool": [xp.bool], + "signed integer": int_dtypes, + "unsigned integer": uint_dtypes, + "integral": all_int_dtypes, + "real floating": float_dtypes, + "complex floating": complex_dtypes, + "numeric": numeric_dtypes, +} + + def is_int_dtype(dtype): return dtype in all_int_dtypes diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 6bf6ed7a..ec0e2a2e 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -177,6 +177,33 @@ def test_iinfo(dtype): # TODO: test values +def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]: + return xps.scalar_dtypes() | st.sampled_from(list(dh.kind_to_dtypes.keys())) + + +@pytest.mark.min_version("2022.12") +@given( + dtype=xps.scalar_dtypes(), + kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple), +) +def test_isdtype(dtype, kind): + out = xp.isdtype(dtype, kind) + + assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]" + _kinds = kind if isinstance(kind, tuple) else (kind,) + expected = False + for _kind in _kinds: + if isinstance(_kind, str): + if dtype in dh.kind_to_dtypes[_kind]: + expected = True + break + else: + if dtype == _kind: + expected = True + break + assert out == expected, f"{out=}, but should be {expected} [isdtype()]" + + @given(hh.mutually_promotable_dtypes(None)) def test_result_type(dtypes): out = xp.result_type(*dtypes)