diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index aeaa4610..899a2591 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -62,7 +62,7 @@ def __repr__(self): ] _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] -_funcs += ["take", "isdtype"] # TODO: bump spec and update array-api-tests to new spec layout +_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS for attr in _top_level_attrs: diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index d9d2362d..56a89f63 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -2,28 +2,33 @@ from collections import defaultdict from collections.abc import Mapping from functools import lru_cache -from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union +from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union from warnings import warn -from . import _array_module as xp from . import api_version -from ._array_module import _UndefinedStub -from ._array_module import mod as _xp +from ._array_module import mod as xp from .stubs import name_to_func from .typing import DataType, ScalarType __all__ = [ + "uint_names", + "int_names", + "all_int_names", + "real_float_names", + "real_names", + "complex_names", + "numeric_names", + "dtype_names", "int_dtypes", "uint_dtypes", "all_int_dtypes", - "float_dtypes", + "real_float_dtypes", "real_dtypes", "numeric_dtypes", "all_dtypes", "all_float_dtypes", "bool_and_all_int_dtypes", "dtype_to_name", - "dtype_to_scalars", "kind_to_dtypes", "is_int_dtype", "is_float_dtype", @@ -90,62 +95,59 @@ def __repr__(self): return f"EqualityMapping({self})" -def _filter_stubs(*args): - for a in args: - if not isinstance(a, _UndefinedStub): - yield a +uint_names = ("uint8", "uint16", "uint32", "uint64") +int_names = ("int8", "int16", "int32", "int64") +all_int_names = uint_names + int_names +real_float_names = ("float32", "float64") +real_names = uint_names + int_names + real_float_names +complex_names = ("complex64", "complex128") +numeric_names = real_names + complex_names +dtype_names = ("bool",) + numeric_names -_uint_names = ("uint8", "uint16", "uint32", "uint64") -_int_names = ("int8", "int16", "int32", "int64") -_float_names = ("float32", "float64") -_real_names = _uint_names + _int_names + _float_names -_complex_names = ("complex64", "complex128") -_numeric_names = _real_names + _complex_names -_dtype_names = ("bool",) + _numeric_names +_name_to_dtype = {} +for name in dtype_names: + try: + dtype = getattr(xp, name) + except AttributeError: + continue + _name_to_dtype[name] = dtype +dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()]) -uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) -int_dtypes = tuple(getattr(xp, name) for name in _int_names) -float_dtypes = tuple(getattr(xp, name) for name in _float_names) +def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]: + dtypes = [] + for name in names: + try: + dtype = _name_to_dtype[name] + except KeyError: + continue + dtypes.append(dtype) + return tuple(dtypes) + + +uint_dtypes = _make_dtype_tuple_from_names(uint_names) +int_dtypes = _make_dtype_tuple_from_names(int_names) +real_float_dtypes = _make_dtype_tuple_from_names(real_float_names) all_int_dtypes = uint_dtypes + int_dtypes -real_dtypes = all_int_dtypes + float_dtypes -complex_dtypes = tuple(getattr(xp, name) for name in _complex_names) +real_dtypes = all_int_dtypes + real_float_dtypes +complex_dtypes = _make_dtype_tuple_from_names(complex_names) numeric_dtypes = real_dtypes if api_version > "2021.12": numeric_dtypes += complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes -all_float_dtypes = float_dtypes +all_float_dtypes = real_float_dtypes if api_version > "2021.12": all_float_dtypes += complex_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -_dtype_name_pairs = [] -for name in _dtype_names: - try: - dtype = getattr(_xp, name) - except AttributeError: - continue - _dtype_name_pairs.append((dtype, name)) -dtype_to_name = EqualityMapping(_dtype_name_pairs) - - -dtype_to_scalars = EqualityMapping( - [ - (xp.bool, [bool]), - *[(d, [int]) for d in all_int_dtypes], - *[(d, [int, float]) for d in float_dtypes], - ] -) - - kind_to_dtypes = { "bool": [xp.bool], "signed integer": int_dtypes, "unsigned integer": uint_dtypes, "integral": all_int_dtypes, - "real floating": float_dtypes, + "real floating": real_float_dtypes, "complex floating": complex_dtypes, "numeric": numeric_dtypes, } @@ -162,16 +164,16 @@ def is_float_dtype(dtype): # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - valid_dtypes = float_dtypes + valid_dtypes = real_float_dtypes if api_version > "2021.12": valid_dtypes += complex_dtypes return dtype in valid_dtypes def get_scalar_type(dtype: DataType) -> ScalarType: - if is_int_dtype(dtype): + if dtype in all_int_dtypes: return int - elif is_float_dtype(dtype): + elif dtype in real_float_dtypes: return float elif dtype in complex_dtypes: return complex @@ -179,47 +181,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType: return bool +def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: + dtype_value_pairs = [] + for name, value in mapping.items(): + assert isinstance(name, str) and name in dtype_names # sanity check + try: + dtype = getattr(xp, name) + except AttributeError: + continue + dtype_value_pairs.append((dtype, value)) + return EqualityMapping(dtype_value_pairs) + + class MinMax(NamedTuple): min: Union[int, float] max: Union[int, float] -dtype_ranges = EqualityMapping( - [ - (xp.int8, MinMax(-128, +127)), - (xp.int16, MinMax(-32_768, +32_767)), - (xp.int32, MinMax(-2_147_483_648, +2_147_483_647)), - (xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)), - (xp.uint8, MinMax(0, +255)), - (xp.uint16, MinMax(0, +65_535)), - (xp.uint32, MinMax(0, +4_294_967_295)), - (xp.uint64, MinMax(0, +18_446_744_073_709_551_615)), - (xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)), - (xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)), - ] +dtype_ranges = _make_dtype_mapping_from_names( + { + "int8": MinMax(-128, +127), + "int16": MinMax(-32_768, +32_767), + "int32": MinMax(-2_147_483_648, +2_147_483_647), + "int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), + "uint8": MinMax(0, +255), + "uint16": MinMax(0, +65_535), + "uint32": MinMax(0, +4_294_967_295), + "uint64": MinMax(0, +18_446_744_073_709_551_615), + "float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38), + "float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308), + } ) -dtype_nbits = EqualityMapping( - [(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)] - + [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)] - + [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)] - + [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)] - + [(d, 128) for d in _filter_stubs(xp.complex128)] -) +r_nbits = re.compile(r"[a-z]+([0-9]+)") +_dtype_nbits: Dict[str, int] = {} +for name in numeric_names: + m = r_nbits.fullmatch(name) + assert m is not None # sanity check / for mypy + _dtype_nbits[name] = int(m.group(1)) +dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits) -dtype_signed = EqualityMapping( - [(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes] +dtype_signed = _make_dtype_mapping_from_names( + {**{name: True for name in int_names}, **{name: False for name in uint_names}} ) -dtype_components = EqualityMapping( - [(xp.complex64, xp.float32), (xp.complex128, xp.float64)] +dtype_components = _make_dtype_mapping_from_names( + {"complex64": xp.float32, "complex128": xp.float64} ) -if isinstance(xp.asarray, _UndefinedStub): +if not hasattr(xp, "asarray"): default_int = xp.int32 default_float = xp.float32 warn( @@ -231,7 +245,7 @@ class MinMax(NamedTuple): if default_int not in int_dtypes: warn(f"inferred default int is {default_int!r}, which is not an int") default_float = xp.asarray(float()).dtype - if default_float not in float_dtypes: + if default_float not in real_float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") if api_version > "2021.12": default_complex = xp.asarray(complex()).dtype @@ -243,60 +257,73 @@ class MinMax(NamedTuple): else: default_complex = None if dtype_nbits[default_int] == 32: - default_uint = xp.uint32 + default_uint = getattr(xp, "uint32", None) else: - default_uint = xp.uint64 - + default_uint = getattr(xp, "uint64", None) -_numeric_promotions = [ +_promotion_table: Dict[Tuple[str, str], str] = { + ("bool", "bool"): "bool", # ints - ((xp.int8, xp.int8), xp.int8), - ((xp.int8, xp.int16), xp.int16), - ((xp.int8, xp.int32), xp.int32), - ((xp.int8, xp.int64), xp.int64), - ((xp.int16, xp.int16), xp.int16), - ((xp.int16, xp.int32), xp.int32), - ((xp.int16, xp.int64), xp.int64), - ((xp.int32, xp.int32), xp.int32), - ((xp.int32, xp.int64), xp.int64), - ((xp.int64, xp.int64), xp.int64), + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + ("int32", "int64"): "int64", + ("int64", "int64"): "int64", # uints - ((xp.uint8, xp.uint8), xp.uint8), - ((xp.uint8, xp.uint16), xp.uint16), - ((xp.uint8, xp.uint32), xp.uint32), - ((xp.uint8, xp.uint64), xp.uint64), - ((xp.uint16, xp.uint16), xp.uint16), - ((xp.uint16, xp.uint32), xp.uint32), - ((xp.uint16, xp.uint64), xp.uint64), - ((xp.uint32, xp.uint32), xp.uint32), - ((xp.uint32, xp.uint64), xp.uint64), - ((xp.uint64, xp.uint64), xp.uint64), + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", # ints and uints (mixed sign) - ((xp.int8, xp.uint8), xp.int16), - ((xp.int8, xp.uint16), xp.int32), - ((xp.int8, xp.uint32), xp.int64), - ((xp.int16, xp.uint8), xp.int16), - ((xp.int16, xp.uint16), xp.int32), - ((xp.int16, xp.uint32), xp.int64), - ((xp.int32, xp.uint8), xp.int32), - ((xp.int32, xp.uint16), xp.int32), - ((xp.int32, xp.uint32), xp.int64), - ((xp.int64, xp.uint8), xp.int64), - ((xp.int64, xp.uint16), xp.int64), - ((xp.int64, xp.uint32), xp.int64), + ("int8", "uint8"): "int16", + ("int8", "uint16"): "int32", + ("int8", "uint32"): "int64", + ("int16", "uint8"): "int16", + ("int16", "uint16"): "int32", + ("int16", "uint32"): "int64", + ("int32", "uint8"): "int32", + ("int32", "uint16"): "int32", + ("int32", "uint32"): "int64", + ("int64", "uint8"): "int64", + ("int64", "uint16"): "int64", + ("int64", "uint32"): "int64", # floats - ((xp.float32, xp.float32), xp.float32), - ((xp.float32, xp.float64), xp.float64), - ((xp.float64, xp.float64), xp.float64), + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", # complex - ((xp.complex64, xp.complex64), xp.complex64), - ((xp.complex64, xp.complex128), xp.complex128), - ((xp.complex128, xp.complex128), xp.complex128), -] -_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] -_promotion_table = list(set(_numeric_promotions)) -_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool)) -promotion_table = EqualityMapping(_promotion_table) + ("complex64", "complex64"): "complex64", + ("complex64", "complex128"): "complex128", + ("complex128", "complex128"): "complex128", +} +_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()}) +_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = [] +for (in_name1, in_name2), res_name in _promotion_table.items(): + try: + in_dtype1 = getattr(xp, in_name1) + except AttributeError: + continue + try: + in_dtype2 = getattr(xp, in_name2) + except AttributeError: + continue + try: + res_dtype = getattr(xp, res_name) + except AttributeError: + continue + _promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype)) +promotion_table = EqualityMapping(_promotion_table_pairs) def result_type(*dtypes: DataType): @@ -319,12 +346,13 @@ def result_type(*dtypes: DataType): category_to_dtypes = { "boolean": (xp.bool,), "integer": all_int_dtypes, - "floating-point": float_dtypes, + "floating-point": real_float_dtypes, "numeric": numeric_dtypes, "integer or boolean": bool_and_all_int_dtypes, } func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes) for name, func in name_to_func.items(): + assert func.__doc__ is not None # for mypy if m := r_in_dtypes.search(func.__doc__): dtype_category = m.group(1) if dtype_category == "numeric" and r_int_note.search(func.__doc__): @@ -332,7 +360,7 @@ def result_type(*dtypes: DataType): dtypes = category_to_dtypes[dtype_category] func_in_dtypes[name] = dtypes # See https://github.com/data-apis/array-api/pull/413 -func_in_dtypes["expm1"] = float_dtypes +func_in_dtypes["expm1"] = real_float_dtypes func_returns_bool = { @@ -457,11 +485,10 @@ def result_type(*dtypes: DataType): } +# Construct func_in_dtypes and func_returns bool for op, elwise_func in op_to_func.items(): func_in_dtypes[op] = func_in_dtypes[elwise_func] func_returns_bool[op] = func_returns_bool[elwise_func] - - inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): if op == "__matmul__" or func_returns_bool[op]: @@ -470,12 +497,10 @@ def result_type(*dtypes: DataType): inplace_op_to_symbol[iop] = f"{symbol}=" func_in_dtypes[iop] = func_in_dtypes[op] func_returns_bool[iop] = func_returns_bool[op] - - func_in_dtypes["__bool__"] = (xp.bool,) func_in_dtypes["__int__"] = all_int_dtypes func_in_dtypes["__index__"] = all_int_dtypes -func_in_dtypes["__float__"] = float_dtypes +func_in_dtypes["__float__"] = real_float_dtypes func_in_dtypes["from_dlpack"] = numeric_dtypes func_in_dtypes["__dlpack__"] = numeric_dtypes diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 04369214..31f1e153 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -1,4 +1,6 @@ +import re import itertools +from contextlib import contextmanager from functools import reduce from math import sqrt from operator import mul @@ -39,7 +41,7 @@ shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") -_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes] +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes] _sorted_dtypes = [d for category in _dtype_categories for d in category] def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]): @@ -477,3 +479,14 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]: axes_strats.append(integers(-ndim, ndim - 1)) axes_strats.append(xps.valid_tuple_axes(ndim)) return one_of(axes_strats) + + +@contextmanager +def reject_overflow(): + try: + yield + except Exception as e: + if isinstance(e, OverflowError) or re.search("[Oo]verflow", str(e)): + reject() + else: + raise e diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index 647cc145..b3e5cf3d 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -1,8 +1,10 @@ from math import prod +from typing import Type import pytest from hypothesis import given, settings from hypothesis import strategies as st +from hypothesis.errors import Unsatisfiable from .. import _array_module as xp from .. import array_helpers as ah @@ -15,7 +17,7 @@ UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] -@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes)) +@given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes)) def test_mutually_promotable_dtypes(pair): assert pair in ( (xp.float32, xp.float32), @@ -144,3 +146,27 @@ def test_symmetric_matrices(m, dtype, finite): def test_positive_definite_matrices(m, dtype): assert m.dtype == dtype # TODO: Test that it actually is positive definite + + +def make_raising_func(cls: Type[Exception], msg: str): + def raises(): + raise cls(msg) + + return raises + +@pytest.mark.parametrize( + "func", + [ + make_raising_func(OverflowError, "foo"), + make_raising_func(RuntimeError, "Overflow when unpacking long"), + make_raising_func(Exception, "Got an overflow"), + ] +) +def test_reject_overflow(func): + @given(data=st.data()) + def test_case(data): + with hh.reject_overflow(): + func() + + with pytest.raises(Unsatisfiable): + test_case() diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index deeab264..dbd99495 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,5 +1,5 @@ import pytest -from hypothesis import given, reject +from hypothesis import given from hypothesis import strategies as st from .. import _array_module as xp @@ -105,10 +105,8 @@ def test_fmt_idx(idx, expected): @given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes()) def test_int_to_dtype(x, dtype): - try: + with hh.reject_overflow(): d = xp.asarray(x, dtype=dtype) - except OverflowError: - reject() assert mock_int_dtype(x, dtype) == d diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 0d354c0b..461532e7 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -446,7 +446,7 @@ def assert_array_elements( dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" - if out.dtype in dh.float_dtypes: + if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 7569d009..42c0aef0 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -13,6 +13,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps +from ._array_module import mod as _xp from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -241,21 +242,27 @@ def test_setitem_masking(shape, data): ) -def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param: +def make_scalar_casting_param( + method_name: str, dtype_name: DataType, stype: ScalarType +) -> Param: return pytest.param( - method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})" + method_name, dtype_name, stype, id=f"{method_name}({dtype_name})" ) @pytest.mark.parametrize( - "method_name, dtype, stype", - [make_param("__bool__", xp.bool, bool)] - + [make_param("__int__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)] - + [make_param("__index__", d, int) for d in dh._filter_stubs(*dh.all_int_dtypes)] - + [make_param("__float__", d, float) for d in dh.float_dtypes], + "method_name, dtype_name, stype", + [make_scalar_casting_param("__bool__", "bool", bool)] + + [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names] + + [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names] + + [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names], ) @given(data=st.data()) -def test_scalar_casting(method_name, dtype, stype, data): +def test_scalar_casting(method_name, dtype_name, stype, data): + try: + dtype = getattr(_xp, dtype_name) + except AttributeError as e: + pytest.skip(str(e)) x = data.draw(xps.arrays(dtype, shape=()), label="x") method = getattr(x, method_name) out = method() diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index cc6acbbe..1c2a24f7 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -391,7 +391,8 @@ def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex] kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw"), ) def test_full(shape, fill_value, kw): - out = xp.full(shape, fill_value, **kw) + with hh.reject_overflow(): + out = xp.full(shape, fill_value, **kw) if kw.get("dtype", None): dtype = kw["dtype"] elif isinstance(fill_value, bool): diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index dc30ce7b..917b1f26 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -11,6 +11,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps +from ._array_module import mod as _xp from .typing import DataType pytestmark = pytest.mark.ci @@ -123,7 +124,7 @@ def test_can_cast(_from, to, data): expected = to == xp.bool else: same_family = None - for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]: + for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]: if _from in dtypes: same_family = to in dtypes break @@ -141,12 +142,12 @@ def test_can_cast(_from, to, data): assert out == expected, f"{out=}, but should be {expected} {f_func}" -def make_dtype_id(dtype: DataType) -> str: - return dh.dtype_to_name[dtype] - - -@pytest.mark.parametrize("dtype", dh.float_dtypes, ids=make_dtype_id) -def test_finfo(dtype): +@pytest.mark.parametrize("dtype_name", dh.real_float_names) +def test_finfo(dtype_name): + try: + dtype = getattr(_xp, dtype_name) + except AttributeError as e: + pytest.skip(str(e)) out = xp.finfo(dtype) f_func = f"[finfo({dh.dtype_to_name[dtype]})]" for attr, stype in [ @@ -164,8 +165,12 @@ def test_finfo(dtype): # TODO: test values -@pytest.mark.parametrize("dtype", dh._filter_stubs(*dh.all_int_dtypes), ids=make_dtype_id) -def test_iinfo(dtype): +@pytest.mark.parametrize("dtype_name", dh.all_int_names) +def test_iinfo(dtype_name): + try: + dtype = getattr(_xp, dtype_name) + except AttributeError as e: + pytest.skip(str(e)) out = xp.iinfo(dtype) f_func = f"[iinfo({dh.dtype_to_name[dtype]})]" for attr in ["bits", "max", "min"]: diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 75b9105e..0974805e 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -265,6 +265,7 @@ def test_eigvalsh(x): # TODO: Test that res actually corresponds to the eigenvalues of x +@pytest.mark.skip(reason="flaky") @pytest.mark.xp_extension('linalg') @given(x=invertible_matrices()) def test_inv(x): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index ec519ebb..4d803bb0 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -8,7 +8,7 @@ from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union import pytest -from hypothesis import assume, given, reject +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp, api_version @@ -740,10 +740,8 @@ def test_add(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - try: + with hh.reject_overflow(): res = ctx.func(left, right) - except OverflowError: - reject() binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) @@ -776,7 +774,7 @@ def test_atan(x): unary_assert_against_refimpl("atan", x, out, math.atan) -@given(*hh.two_mutual_arrays(dh.float_dtypes)) +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) @@ -1204,7 +1202,7 @@ def logaddexp(l: float, r: float) -> float: return math.log(math.exp(l) + math.exp(r)) -@given(*hh.two_mutual_arrays(dh.float_dtypes)) +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) @@ -1327,10 +1325,8 @@ def test_pow(ctx, data): if dh.is_int_dtype(right.dtype): assume(xp.all(right >= 0)) - try: + with hh.reject_overflow(): res = ctx.func(left, right) - except OverflowError: - reject() binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) @@ -1425,10 +1421,8 @@ def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - try: + with hh.reject_overflow(): res = ctx.func(left, right) - except OverflowError: - reject() binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index b09e1379..06837eee 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -16,7 +16,7 @@ @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), @@ -53,7 +53,7 @@ def test_argmax(x, data): @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index 14e65802..67b6ff50 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -34,7 +34,7 @@ def assert_scalar_in_set( # TODO: Test with signed zeros and NaNs (and ignore them somehow) @given( x=xps.arrays( - dtype=xps.scalar_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), @@ -94,7 +94,7 @@ def test_argsort(x, data): # TODO: Test with signed zeros and NaNs (and ignore them somehow) @given( x=xps.arrays( - dtype=xps.scalar_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 4aefb5b9..345d7fe5 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1231,7 +1231,7 @@ def test_unary(func_name, func, case, x, data): x1_strat, x2_strat = hh.two_mutual_arrays( - dtypes=dh.float_dtypes, + dtypes=dh.real_float_dtypes, two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), ) @@ -1277,7 +1277,7 @@ def test_binary(func_name, func, case, x1, x2, data): @pytest.mark.parametrize("iop_name, iop, case", iop_params) @given( - oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes), + oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes), oneway_shapes=hh.oneway_broadcastable_shapes(), data=st.data(), ) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index b4c92590..990ae5c7 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -5,7 +5,6 @@ import pytest from hypothesis import assume, given from hypothesis import strategies as st -from hypothesis.control import reject from . import _array_module as xp from . import dtype_helpers as dh @@ -28,7 +27,7 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -80,7 +79,7 @@ def test_mean(x, data): @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.real_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -127,10 +126,8 @@ def test_prod(x, data): ) keepdims = kw.get("keepdims", False) - try: + with hh.reject_overflow(): out = xp.prod(x, **kw) - except OverflowError: - reject() dtype = kw.get("dtype", None) if dtype is None: @@ -139,12 +136,15 @@ def test_prod(x, data): default_dtype = dh.default_uint else: default_dtype = dh.default_int - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype + if default_dtype is None: + _dtype = None else: - _dtype = default_dtype + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype @@ -152,11 +152,11 @@ def test_prod(x, data): _dtype = dh.default_float else: _dtype = dtype - if isinstance(_dtype, _UndefinedStub): + if _dtype is None: # If a default uint cannot exist (i.e. in PyTorch which doesn't support # uint32 or uint64), we skip testing the output dtype. # See https://github.com/data-apis/array-api-tests/issues/106 - if _dtype in dh.uint_dtypes: + if x.dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) @@ -234,10 +234,8 @@ def test_sum(x, data): ) keepdims = kw.get("keepdims", False) - try: + with hh.reject_overflow(): out = xp.sum(x, **kw) - except OverflowError: - reject() dtype = kw.get("dtype", None) if dtype is None: @@ -246,12 +244,15 @@ def test_sum(x, data): default_dtype = dh.default_uint else: default_dtype = dh.default_int - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype + if default_dtype is None: + _dtype = None else: - _dtype = default_dtype + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype @@ -259,11 +260,11 @@ def test_sum(x, data): _dtype = dh.default_float else: _dtype = dtype - if isinstance(_dtype, _UndefinedStub): + if _dtype is None: # If a default uint cannot exist (i.e. in PyTorch which doesn't support # uint32 or uint64), we skip testing the output dtype. # See https://github.com/data-apis/array-api-tests/issues/160 - if _dtype in dh.uint_dtypes: + if x.dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)