diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index b4c0e8ba..27b6fe1e 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,14 +1,12 @@ from functools import wraps from hypothesis import strategies as st -from hypothesis.extra.array_api import make_strategies_namespace +from hypothesis.extra import array_api from ._array_module import mod as _xp __all__ = ["xps"] -xps = make_strategies_namespace(_xp) - # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -23,5 +21,29 @@ def floats(*a, **kw): st.floats = floats + +# We do the same with xps.from_dtype() - this is not strictly necessary, as +# the underlying floats() will never generate subnormals. We only do this +# because internal logic in xps.from_dtype() assumes xp.finfo() has its +# attributes as scalar floats, which is expected behaviour but disrupts many +# unrelated tests. +try: + __from_dtype = array_api._from_dtype + + @wraps(__from_dtype) + def _from_dtype(*a, **kw): + kw["allow_subnormal"] = False + return __from_dtype(*a, **kw) + + array_api._from_dtype = _from_dtype +except AttributeError: + # Ignore monkey patching if Hypothesis changes the private API + pass + + +xps = array_api.make_strategies_namespace(_xp) + + from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 29d87216..c8e76a90 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from functools import lru_cache -from typing import NamedTuple, Tuple, Union +from typing import Any, NamedTuple, Sequence, Tuple, Union from warnings import warn from . import _array_module as xp @@ -36,6 +37,49 @@ ] +class EqualityMapping(Mapping): + """ + Mapping that uses equality for indexing + + Typical mappings (e.g. the built-in dict) use hashing for indexing. This + isn't ideal for the Array API, as no __hash__() method is specified for + dtype objects - but __eq__() is! + + See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects + """ + + def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]): + keys = [k for k, _ in key_value_pairs] + for i, key in enumerate(keys): + if not (key == key): # specifically checking __eq__, not __neq__ + raise ValueError("Key {key!r} does not have equality with itself") + other_keys = keys[:] + other_keys.pop(i) + for other_key in other_keys: + if key == other_key: + raise ValueError("Key {key!r} has equality with key {other_key!r}") + self._key_value_pairs = key_value_pairs + + def __getitem__(self, key): + for k, v in self._key_value_pairs: + if key == k: + return v + else: + raise KeyError(f"{key!r} not found") + + def __iter__(self): + return (k for k, _ in self._key_value_pairs) + + def __len__(self): + return len(self._key_value_pairs) + + def __str__(self): + return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}" + + def __repr__(self): + return f"EqualityMapping({self})" + + _uint_names = ("uint8", "uint16", "uint32", "uint64") _int_names = ("int8", "int16", "int32", "int64") _float_names = ("float32", "float64") @@ -51,14 +95,16 @@ bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtype_to_name = {getattr(xp, name): name for name in _dtype_names} +dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names]) -dtype_to_scalars = { - xp.bool: [bool], - **{d: [int] for d in all_int_dtypes}, - **{d: [int, float] for d in float_dtypes}, -} +dtype_to_scalars = EqualityMapping( + [ + (xp.bool, [bool]), + *[(d, [int]) for d in all_int_dtypes], + *[(d, [int, float]) for d in float_dtypes], + ] +) def is_int_dtype(dtype): @@ -90,31 +136,32 @@ class MinMax(NamedTuple): max: Union[int, float] -dtype_ranges = { - xp.int8: MinMax(-128, +127), - xp.int16: MinMax(-32_768, +32_767), - xp.int32: MinMax(-2_147_483_648, +2_147_483_647), - xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), - xp.uint8: MinMax(0, +255), - xp.uint16: MinMax(0, +65_535), - xp.uint32: MinMax(0, +4_294_967_295), - xp.uint64: MinMax(0, +18_446_744_073_709_551_615), - xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38), - xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308), -} +dtype_ranges = EqualityMapping( + [ + (xp.int8, MinMax(-128, +127)), + (xp.int16, MinMax(-32_768, +32_767)), + (xp.int32, MinMax(-2_147_483_648, +2_147_483_647)), + (xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)), + (xp.uint8, MinMax(0, +255)), + (xp.uint16, MinMax(0, +65_535)), + (xp.uint32, MinMax(0, +4_294_967_295)), + (xp.uint64, MinMax(0, +18_446_744_073_709_551_615)), + (xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)), + (xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)), + ] +) -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} +dtype_nbits = EqualityMapping( + [(d, 8) for d in [xp.int8, xp.uint8]] + + [(d, 16) for d in [xp.int16, xp.uint16]] + + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] + + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]] +) -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} +dtype_signed = EqualityMapping( + [(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes] +) if isinstance(xp.asarray, _UndefinedStub): @@ -137,52 +184,51 @@ class MinMax(NamedTuple): default_uint = xp.uint64 -_numeric_promotions = { +_numeric_promotions = [ # ints - (xp.int8, xp.int8): xp.int8, - (xp.int8, xp.int16): xp.int16, - (xp.int8, xp.int32): xp.int32, - (xp.int8, xp.int64): xp.int64, - (xp.int16, xp.int16): xp.int16, - (xp.int16, xp.int32): xp.int32, - (xp.int16, xp.int64): xp.int64, - (xp.int32, xp.int32): xp.int32, - (xp.int32, xp.int64): xp.int64, - (xp.int64, xp.int64): xp.int64, + ((xp.int8, xp.int8), xp.int8), + ((xp.int8, xp.int16), xp.int16), + ((xp.int8, xp.int32), xp.int32), + ((xp.int8, xp.int64), xp.int64), + ((xp.int16, xp.int16), xp.int16), + ((xp.int16, xp.int32), xp.int32), + ((xp.int16, xp.int64), xp.int64), + ((xp.int32, xp.int32), xp.int32), + ((xp.int32, xp.int64), xp.int64), + ((xp.int64, xp.int64), xp.int64), # uints - (xp.uint8, xp.uint8): xp.uint8, - (xp.uint8, xp.uint16): xp.uint16, - (xp.uint8, xp.uint32): xp.uint32, - (xp.uint8, xp.uint64): xp.uint64, - (xp.uint16, xp.uint16): xp.uint16, - (xp.uint16, xp.uint32): xp.uint32, - (xp.uint16, xp.uint64): xp.uint64, - (xp.uint32, xp.uint32): xp.uint32, - (xp.uint32, xp.uint64): xp.uint64, - (xp.uint64, xp.uint64): xp.uint64, + ((xp.uint8, xp.uint8), xp.uint8), + ((xp.uint8, xp.uint16), xp.uint16), + ((xp.uint8, xp.uint32), xp.uint32), + ((xp.uint8, xp.uint64), xp.uint64), + ((xp.uint16, xp.uint16), xp.uint16), + ((xp.uint16, xp.uint32), xp.uint32), + ((xp.uint16, xp.uint64), xp.uint64), + ((xp.uint32, xp.uint32), xp.uint32), + ((xp.uint32, xp.uint64), xp.uint64), + ((xp.uint64, xp.uint64), xp.uint64), # ints and uints (mixed sign) - (xp.int8, xp.uint8): xp.int16, - (xp.int8, xp.uint16): xp.int32, - (xp.int8, xp.uint32): xp.int64, - (xp.int16, xp.uint8): xp.int16, - (xp.int16, xp.uint16): xp.int32, - (xp.int16, xp.uint32): xp.int64, - (xp.int32, xp.uint8): xp.int32, - (xp.int32, xp.uint16): xp.int32, - (xp.int32, xp.uint32): xp.int64, - (xp.int64, xp.uint8): xp.int64, - (xp.int64, xp.uint16): xp.int64, - (xp.int64, xp.uint32): xp.int64, + ((xp.int8, xp.uint8), xp.int16), + ((xp.int8, xp.uint16), xp.int32), + ((xp.int8, xp.uint32), xp.int64), + ((xp.int16, xp.uint8), xp.int16), + ((xp.int16, xp.uint16), xp.int32), + ((xp.int16, xp.uint32), xp.int64), + ((xp.int32, xp.uint8), xp.int32), + ((xp.int32, xp.uint16), xp.int32), + ((xp.int32, xp.uint32), xp.int64), + ((xp.int64, xp.uint8), xp.int64), + ((xp.int64, xp.uint16), xp.int64), + ((xp.int64, xp.uint32), xp.int64), # floats - (xp.float32, xp.float32): xp.float32, - (xp.float32, xp.float64): xp.float64, - (xp.float64, xp.float64): xp.float64, -} -promotion_table = { - (xp.bool, xp.bool): xp.bool, - **_numeric_promotions, - **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, -} + ((xp.float32, xp.float32), xp.float32), + ((xp.float32, xp.float64), xp.float64), + ((xp.float64, xp.float64), xp.float64), +] +_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] +_promotion_table = list(set(_numeric_promotions)) +_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool)) +promotion_table = EqualityMapping(_promotion_table) def result_type(*dtypes: DataType): diff --git a/array_api_tests/meta/test_equality_mapping.py b/array_api_tests/meta/test_equality_mapping.py new file mode 100644 index 00000000..86fa7e14 --- /dev/null +++ b/array_api_tests/meta/test_equality_mapping.py @@ -0,0 +1,37 @@ +import pytest + +from ..dtype_helpers import EqualityMapping + + +def test_raises_on_distinct_eq_key(): + with pytest.raises(ValueError): + EqualityMapping([(float("nan"), "value")]) + + +def test_raises_on_indistinct_eq_keys(): + class AlwaysEq: + def __init__(self, hash): + self._hash = hash + + def __eq__(self, other): + return True + + def __hash__(self): + return self._hash + + with pytest.raises(ValueError): + EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")]) + + +def test_key_error(): + mapping = EqualityMapping([("key", "value")]) + with pytest.raises(KeyError): + mapping["nonexistent key"] + + +def test_iter(): + mapping = EqualityMapping([("key", "value")]) + it = iter(mapping) + assert next(it) == "key" + with pytest.raises(StopIteration): + next(it)