From a5447419f96d2f42b02459d57a12d145aa946c8e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Apr 2022 11:11:20 +0100 Subject: [PATCH 1/6] `EqualityMapping` class --- array_api_tests/dtype_helpers.py | 30 +++++++++++++++++-- array_api_tests/meta/test_equality_mapping.py | 23 ++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 array_api_tests/meta/test_equality_mapping.py diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 29d87216..ec7a4caf 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,3 +1,4 @@ +from collections import Mapping from functools import lru_cache from typing import NamedTuple, Tuple, Union from warnings import warn @@ -99,8 +100,8 @@ class MinMax(NamedTuple): xp.uint16: MinMax(0, +65_535), xp.uint32: MinMax(0, +4_294_967_295), xp.uint64: MinMax(0, +18_446_744_073_709_551_615), - xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38), - xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308), + xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), + xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), } dtype_nbits = { @@ -404,3 +405,28 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) return ", ".join(f_types) + + +class EqualityMapping(Mapping): + def __init__(self, mapping: Mapping): + keys = list(mapping.keys()) + for i, key in enumerate(keys): + if not (key == key): # specifically test __eq__, not __neq__ + raise ValueError("Key {key!r} does not have equality with itself") + other_keys = keys[:] + other_keys.pop(i) + for other_key in other_keys: + if key == other_key: + raise ValueError("Key {key!r} has equality with key {other_key!r}") + self._mapping = mapping + + def __getitem__(self, key): + for k, v in self._mapping.items(): + if key == k: + return v + + def __iter__(self): + return iter(self._mapping) + + def __len__(self): + return len(self._mapping) diff --git a/array_api_tests/meta/test_equality_mapping.py b/array_api_tests/meta/test_equality_mapping.py new file mode 100644 index 00000000..8eba2b03 --- /dev/null +++ b/array_api_tests/meta/test_equality_mapping.py @@ -0,0 +1,23 @@ +import pytest + +from ..dtype_helpers import EqualityMapping + + +def test_raises_on_distinct_eq_key(): + with pytest.raises(ValueError): + EqualityMapping({float("nan"): "foo"}) + + +def test_raises_on_indistinct_eq_keys(): + class AlwaysEq: + def __init__(self, hash): + self._hash = hash + + def __eq__(self, other): + return True + + def __hash__(self): + return self._hash + + with pytest.raises(ValueError): + EqualityMapping({AlwaysEq(0): "foo", AlwaysEq(1): "bar"}) From 22f8b750d4de582ad776e71aaea76d638e921184 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Apr 2022 12:45:08 +0100 Subject: [PATCH 2/6] Fix/test KeyError in `EqualityMapping`, add docstring --- array_api_tests/dtype_helpers.py | 16 +++++++++++++++- array_api_tests/meta/test_equality_mapping.py | 6 ++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index ec7a4caf..e39a09cb 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -408,10 +408,19 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: class EqualityMapping(Mapping): + """ + Mapping that uses equality for indexing + + Typical mappings (e.g. the built-in dict) use hashing for indexing. This + isn't ideal for the Array API, as no __hash__() method is specified for + dtype objects - but __eq__() is! + + See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects + """ def __init__(self, mapping: Mapping): keys = list(mapping.keys()) for i, key in enumerate(keys): - if not (key == key): # specifically test __eq__, not __neq__ + if not (key == key): # specifically checking __eq__, not __neq__ raise ValueError("Key {key!r} does not have equality with itself") other_keys = keys[:] other_keys.pop(i) @@ -424,9 +433,14 @@ def __getitem__(self, key): for k, v in self._mapping.items(): if key == k: return v + else: + raise KeyError(f"{key!r} not found") def __iter__(self): return iter(self._mapping) def __len__(self): return len(self._mapping) + + def __repr__(self): + return f"EqualityMapping({self._mapping!r})" diff --git a/array_api_tests/meta/test_equality_mapping.py b/array_api_tests/meta/test_equality_mapping.py index 8eba2b03..c6209ae1 100644 --- a/array_api_tests/meta/test_equality_mapping.py +++ b/array_api_tests/meta/test_equality_mapping.py @@ -21,3 +21,9 @@ def __hash__(self): with pytest.raises(ValueError): EqualityMapping({AlwaysEq(0): "foo", AlwaysEq(1): "bar"}) + + +def test_key_error(): + mapping = EqualityMapping({"foo": "bar"}) + with pytest.raises(KeyError): + mapping["nonexistent key"] From 50e77750d60c2f7a5f98003e172ca89588e355ec Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Apr 2022 12:54:30 +0100 Subject: [PATCH 3/6] Use `EqualityMapping` for relevant dtype helpers --- array_api_tests/dtype_helpers.py | 161 +++++++++++++++++-------------- 1 file changed, 86 insertions(+), 75 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index e39a09cb..f3e8337f 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -37,6 +37,46 @@ ] +class EqualityMapping(Mapping): + """ + Mapping that uses equality for indexing + + Typical mappings (e.g. the built-in dict) use hashing for indexing. This + isn't ideal for the Array API, as no __hash__() method is specified for + dtype objects - but __eq__() is! + + See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects + """ + + def __init__(self, mapping: Mapping): + keys = list(mapping.keys()) + for i, key in enumerate(keys): + if not (key == key): # specifically checking __eq__, not __neq__ + raise ValueError("Key {key!r} does not have equality with itself") + other_keys = keys[:] + other_keys.pop(i) + for other_key in other_keys: + if key == other_key: + raise ValueError("Key {key!r} has equality with key {other_key!r}") + self._mapping = mapping + + def __getitem__(self, key): + for k, v in self._mapping.items(): + if key == k: + return v + else: + raise KeyError(f"{key!r} not found") + + def __iter__(self): + return iter(self._mapping) + + def __len__(self): + return len(self._mapping) + + def __repr__(self): + return f"EqualityMapping({self._mapping!r})" + + _uint_names = ("uint8", "uint16", "uint32", "uint64") _int_names = ("int8", "int16", "int32", "int64") _float_names = ("float32", "float64") @@ -52,14 +92,16 @@ bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtype_to_name = {getattr(xp, name): name for name in _dtype_names} +dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names}) -dtype_to_scalars = { - xp.bool: [bool], - **{d: [int] for d in all_int_dtypes}, - **{d: [int, float] for d in float_dtypes}, -} +dtype_to_scalars = EqualityMapping( + { + xp.bool: [bool], + **{d: [int] for d in all_int_dtypes}, + **{d: [int, float] for d in float_dtypes}, + } +) def is_int_dtype(dtype): @@ -91,31 +133,37 @@ class MinMax(NamedTuple): max: Union[int, float] -dtype_ranges = { - xp.int8: MinMax(-128, +127), - xp.int16: MinMax(-32_768, +32_767), - xp.int32: MinMax(-2_147_483_648, +2_147_483_647), - xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), - xp.uint8: MinMax(0, +255), - xp.uint16: MinMax(0, +65_535), - xp.uint32: MinMax(0, +4_294_967_295), - xp.uint64: MinMax(0, +18_446_744_073_709_551_615), - xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), - xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), -} - -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} - - -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} +dtype_ranges = EqualityMapping( + { + xp.int8: MinMax(-128, +127), + xp.int16: MinMax(-32_768, +32_767), + xp.int32: MinMax(-2_147_483_648, +2_147_483_647), + xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), + xp.uint8: MinMax(0, +255), + xp.uint16: MinMax(0, +65_535), + xp.uint32: MinMax(0, +4_294_967_295), + xp.uint64: MinMax(0, +18_446_744_073_709_551_615), + xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), + xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), + } +) + +dtype_nbits = EqualityMapping( + { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, + } +) + + +dtype_signed = EqualityMapping( + { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, + } +) if isinstance(xp.asarray, _UndefinedStub): @@ -179,11 +227,13 @@ class MinMax(NamedTuple): (xp.float32, xp.float64): xp.float64, (xp.float64, xp.float64): xp.float64, } -promotion_table = { - (xp.bool, xp.bool): xp.bool, - **_numeric_promotions, - **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, -} +promotion_table = EqualityMapping( + { + (xp.bool, xp.bool): xp.bool, + **_numeric_promotions, + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, + } +) def result_type(*dtypes: DataType): @@ -405,42 +455,3 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) return ", ".join(f_types) - - -class EqualityMapping(Mapping): - """ - Mapping that uses equality for indexing - - Typical mappings (e.g. the built-in dict) use hashing for indexing. This - isn't ideal for the Array API, as no __hash__() method is specified for - dtype objects - but __eq__() is! - - See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects - """ - def __init__(self, mapping: Mapping): - keys = list(mapping.keys()) - for i, key in enumerate(keys): - if not (key == key): # specifically checking __eq__, not __neq__ - raise ValueError("Key {key!r} does not have equality with itself") - other_keys = keys[:] - other_keys.pop(i) - for other_key in other_keys: - if key == other_key: - raise ValueError("Key {key!r} has equality with key {other_key!r}") - self._mapping = mapping - - def __getitem__(self, key): - for k, v in self._mapping.items(): - if key == k: - return v - else: - raise KeyError(f"{key!r} not found") - - def __iter__(self): - return iter(self._mapping) - - def __len__(self): - return len(self._mapping) - - def __repr__(self): - return f"EqualityMapping({self._mapping!r})" From 1fa44c7746013ddfc9dc84a4add5da0b54d25cf9 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Apr 2022 17:12:26 +0100 Subject: [PATCH 4/6] Monkey patch `_from_dtype()` to ignore troublesome internal logic --- array_api_tests/__init__.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index b4c0e8ba..27b6fe1e 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,14 +1,12 @@ from functools import wraps from hypothesis import strategies as st -from hypothesis.extra.array_api import make_strategies_namespace +from hypothesis.extra import array_api from ._array_module import mod as _xp __all__ = ["xps"] -xps = make_strategies_namespace(_xp) - # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -23,5 +21,29 @@ def floats(*a, **kw): st.floats = floats + +# We do the same with xps.from_dtype() - this is not strictly necessary, as +# the underlying floats() will never generate subnormals. We only do this +# because internal logic in xps.from_dtype() assumes xp.finfo() has its +# attributes as scalar floats, which is expected behaviour but disrupts many +# unrelated tests. +try: + __from_dtype = array_api._from_dtype + + @wraps(__from_dtype) + def _from_dtype(*a, **kw): + kw["allow_subnormal"] = False + return __from_dtype(*a, **kw) + + array_api._from_dtype = _from_dtype +except AttributeError: + # Ignore monkey patching if Hypothesis changes the private API + pass + + +xps = array_api.make_strategies_namespace(_xp) + + from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] From 4d1891bc17a4196e116c969666defb839b7eb3eb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Apr 2022 10:50:09 +0100 Subject: [PATCH 5/6] Use `Mapping` from `collections.abc` --- array_api_tests/dtype_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index f3e8337f..cd0bb7f8 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,4 +1,4 @@ -from collections import Mapping +from collections.abc import Mapping from functools import lru_cache from typing import NamedTuple, Tuple, Union from warnings import warn From ed23bfa1ec8d3145753a565171533c6e7d248836 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Apr 2022 14:38:51 +0100 Subject: [PATCH 6/6] Use key-value tuples for `EqualityMapping` as opposed to dicts Drops internal use of hashing via dicts for dtype helpers --- array_api_tests/dtype_helpers.py | 155 +++++++++--------- array_api_tests/meta/test_equality_mapping.py | 14 +- 2 files changed, 86 insertions(+), 83 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index cd0bb7f8..c8e76a90 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from functools import lru_cache -from typing import NamedTuple, Tuple, Union +from typing import Any, NamedTuple, Sequence, Tuple, Union from warnings import warn from . import _array_module as xp @@ -48,8 +48,8 @@ class EqualityMapping(Mapping): See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects """ - def __init__(self, mapping: Mapping): - keys = list(mapping.keys()) + def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]): + keys = [k for k, _ in key_value_pairs] for i, key in enumerate(keys): if not (key == key): # specifically checking __eq__, not __neq__ raise ValueError("Key {key!r} does not have equality with itself") @@ -58,23 +58,26 @@ def __init__(self, mapping: Mapping): for other_key in other_keys: if key == other_key: raise ValueError("Key {key!r} has equality with key {other_key!r}") - self._mapping = mapping + self._key_value_pairs = key_value_pairs def __getitem__(self, key): - for k, v in self._mapping.items(): + for k, v in self._key_value_pairs: if key == k: return v else: raise KeyError(f"{key!r} not found") def __iter__(self): - return iter(self._mapping) + return (k for k, _ in self._key_value_pairs) def __len__(self): - return len(self._mapping) + return len(self._key_value_pairs) + + def __str__(self): + return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}" def __repr__(self): - return f"EqualityMapping({self._mapping!r})" + return f"EqualityMapping({self})" _uint_names = ("uint8", "uint16", "uint32", "uint64") @@ -92,15 +95,15 @@ def __repr__(self): bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names}) +dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names]) dtype_to_scalars = EqualityMapping( - { - xp.bool: [bool], - **{d: [int] for d in all_int_dtypes}, - **{d: [int, float] for d in float_dtypes}, - } + [ + (xp.bool, [bool]), + *[(d, [int]) for d in all_int_dtypes], + *[(d, [int, float]) for d in float_dtypes], + ] ) @@ -134,35 +137,30 @@ class MinMax(NamedTuple): dtype_ranges = EqualityMapping( - { - xp.int8: MinMax(-128, +127), - xp.int16: MinMax(-32_768, +32_767), - xp.int32: MinMax(-2_147_483_648, +2_147_483_647), - xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), - xp.uint8: MinMax(0, +255), - xp.uint16: MinMax(0, +65_535), - xp.uint32: MinMax(0, +4_294_967_295), - xp.uint64: MinMax(0, +18_446_744_073_709_551_615), - xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), - xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), - } + [ + (xp.int8, MinMax(-128, +127)), + (xp.int16, MinMax(-32_768, +32_767)), + (xp.int32, MinMax(-2_147_483_648, +2_147_483_647)), + (xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)), + (xp.uint8, MinMax(0, +255)), + (xp.uint16, MinMax(0, +65_535)), + (xp.uint32, MinMax(0, +4_294_967_295)), + (xp.uint64, MinMax(0, +18_446_744_073_709_551_615)), + (xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)), + (xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)), + ] ) dtype_nbits = EqualityMapping( - { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, - } + [(d, 8) for d in [xp.int8, xp.uint8]] + + [(d, 16) for d in [xp.int16, xp.uint16]] + + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] + + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]] ) dtype_signed = EqualityMapping( - { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, - } + [(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes] ) @@ -186,54 +184,51 @@ class MinMax(NamedTuple): default_uint = xp.uint64 -_numeric_promotions = { +_numeric_promotions = [ # ints - (xp.int8, xp.int8): xp.int8, - (xp.int8, xp.int16): xp.int16, - (xp.int8, xp.int32): xp.int32, - (xp.int8, xp.int64): xp.int64, - (xp.int16, xp.int16): xp.int16, - (xp.int16, xp.int32): xp.int32, - (xp.int16, xp.int64): xp.int64, - (xp.int32, xp.int32): xp.int32, - (xp.int32, xp.int64): xp.int64, - (xp.int64, xp.int64): xp.int64, + ((xp.int8, xp.int8), xp.int8), + ((xp.int8, xp.int16), xp.int16), + ((xp.int8, xp.int32), xp.int32), + ((xp.int8, xp.int64), xp.int64), + ((xp.int16, xp.int16), xp.int16), + ((xp.int16, xp.int32), xp.int32), + ((xp.int16, xp.int64), xp.int64), + ((xp.int32, xp.int32), xp.int32), + ((xp.int32, xp.int64), xp.int64), + ((xp.int64, xp.int64), xp.int64), # uints - (xp.uint8, xp.uint8): xp.uint8, - (xp.uint8, xp.uint16): xp.uint16, - (xp.uint8, xp.uint32): xp.uint32, - (xp.uint8, xp.uint64): xp.uint64, - (xp.uint16, xp.uint16): xp.uint16, - (xp.uint16, xp.uint32): xp.uint32, - (xp.uint16, xp.uint64): xp.uint64, - (xp.uint32, xp.uint32): xp.uint32, - (xp.uint32, xp.uint64): xp.uint64, - (xp.uint64, xp.uint64): xp.uint64, + ((xp.uint8, xp.uint8), xp.uint8), + ((xp.uint8, xp.uint16), xp.uint16), + ((xp.uint8, xp.uint32), xp.uint32), + ((xp.uint8, xp.uint64), xp.uint64), + ((xp.uint16, xp.uint16), xp.uint16), + ((xp.uint16, xp.uint32), xp.uint32), + ((xp.uint16, xp.uint64), xp.uint64), + ((xp.uint32, xp.uint32), xp.uint32), + ((xp.uint32, xp.uint64), xp.uint64), + ((xp.uint64, xp.uint64), xp.uint64), # ints and uints (mixed sign) - (xp.int8, xp.uint8): xp.int16, - (xp.int8, xp.uint16): xp.int32, - (xp.int8, xp.uint32): xp.int64, - (xp.int16, xp.uint8): xp.int16, - (xp.int16, xp.uint16): xp.int32, - (xp.int16, xp.uint32): xp.int64, - (xp.int32, xp.uint8): xp.int32, - (xp.int32, xp.uint16): xp.int32, - (xp.int32, xp.uint32): xp.int64, - (xp.int64, xp.uint8): xp.int64, - (xp.int64, xp.uint16): xp.int64, - (xp.int64, xp.uint32): xp.int64, + ((xp.int8, xp.uint8), xp.int16), + ((xp.int8, xp.uint16), xp.int32), + ((xp.int8, xp.uint32), xp.int64), + ((xp.int16, xp.uint8), xp.int16), + ((xp.int16, xp.uint16), xp.int32), + ((xp.int16, xp.uint32), xp.int64), + ((xp.int32, xp.uint8), xp.int32), + ((xp.int32, xp.uint16), xp.int32), + ((xp.int32, xp.uint32), xp.int64), + ((xp.int64, xp.uint8), xp.int64), + ((xp.int64, xp.uint16), xp.int64), + ((xp.int64, xp.uint32), xp.int64), # floats - (xp.float32, xp.float32): xp.float32, - (xp.float32, xp.float64): xp.float64, - (xp.float64, xp.float64): xp.float64, -} -promotion_table = EqualityMapping( - { - (xp.bool, xp.bool): xp.bool, - **_numeric_promotions, - **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, - } -) + ((xp.float32, xp.float32), xp.float32), + ((xp.float32, xp.float64), xp.float64), + ((xp.float64, xp.float64), xp.float64), +] +_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] +_promotion_table = list(set(_numeric_promotions)) +_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool)) +promotion_table = EqualityMapping(_promotion_table) def result_type(*dtypes: DataType): diff --git a/array_api_tests/meta/test_equality_mapping.py b/array_api_tests/meta/test_equality_mapping.py index c6209ae1..86fa7e14 100644 --- a/array_api_tests/meta/test_equality_mapping.py +++ b/array_api_tests/meta/test_equality_mapping.py @@ -5,7 +5,7 @@ def test_raises_on_distinct_eq_key(): with pytest.raises(ValueError): - EqualityMapping({float("nan"): "foo"}) + EqualityMapping([(float("nan"), "value")]) def test_raises_on_indistinct_eq_keys(): @@ -20,10 +20,18 @@ def __hash__(self): return self._hash with pytest.raises(ValueError): - EqualityMapping({AlwaysEq(0): "foo", AlwaysEq(1): "bar"}) + EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")]) def test_key_error(): - mapping = EqualityMapping({"foo": "bar"}) + mapping = EqualityMapping([("key", "value")]) with pytest.raises(KeyError): mapping["nonexistent key"] + + +def test_iter(): + mapping = EqualityMapping([("key", "value")]) + it = iter(mapping) + assert next(it) == "key" + with pytest.raises(StopIteration): + next(it)