Skip to content

Commit a533680

Browse files
authored
Merge pull request #165 from honno/complex-support
Versioning support, bulk of complex testing
2 parents 0c2c0f7 + ef0e3b1 commit a533680

20 files changed

+443
-201
lines changed

README.md

+7
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ library to fail.
160160

161161
### Configuration
162162

163+
#### API version
164+
165+
You can specify the API version to use when testing via the
166+
`ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the
167+
array module's `__array_api_version__` value, and if that attribute doesn't
168+
exist then we fallback to `"2021.12"`.
169+
163170
#### CI flag
164171

165172
Use the `--ci` flag to run only the primary and special cases tests. You can

array_api_tests/__init__.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from functools import wraps
2+
from os import getenv
23

34
from hypothesis import strategies as st
45
from hypothesis.extra import array_api
56

7+
from . import _version
68
from ._array_module import mod as _xp
79

8-
__all__ = ["xps"]
10+
__all__ = ["api_version", "xps"]
911

1012

1113
# We monkey patch floats() to always disable subnormals as they are out-of-scope
@@ -41,9 +43,9 @@ def _from_dtype(*a, **kw):
4143
pass
4244

4345

44-
xps = array_api.make_strategies_namespace(_xp, api_version="2021.12")
45-
46-
47-
from . import _version
46+
api_version = getenv(
47+
"ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12")
48+
)
49+
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)
4850

4951
__version__ = _version.get_versions()["version"]

array_api_tests/_array_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __repr__(self):
5858
"uint8", "uint16", "uint32", "uint64",
5959
"int8", "int16", "int32", "int64",
6060
"float32", "float64",
61+
"complex64", "complex128",
6162
]
6263
_constants = ["e", "inf", "nan", "pi"]
6364
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]

array_api_tests/dtype_helpers.py

+45-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8+
from . import api_version
89
from . import _array_module as xp
910
from ._array_module import _UndefinedStub
1011
from .stubs import name_to_func
@@ -15,10 +16,12 @@
1516
"uint_dtypes",
1617
"all_int_dtypes",
1718
"float_dtypes",
19+
"real_dtypes",
1820
"numeric_dtypes",
1921
"all_dtypes",
20-
"dtype_to_name",
22+
"all_float_dtypes",
2123
"bool_and_all_int_dtypes",
24+
"dtype_to_name",
2225
"dtype_to_scalars",
2326
"is_int_dtype",
2427
"is_float_dtype",
@@ -27,9 +30,11 @@
2730
"default_int",
2831
"default_uint",
2932
"default_float",
33+
"default_complex",
3034
"promotion_table",
3135
"dtype_nbits",
3236
"dtype_signed",
37+
"dtype_components",
3338
"func_in_dtypes",
3439
"func_returns_bool",
3540
"binary_op_to_symbol",
@@ -86,15 +91,25 @@ def __repr__(self):
8691
_uint_names = ("uint8", "uint16", "uint32", "uint64")
8792
_int_names = ("int8", "int16", "int32", "int64")
8893
_float_names = ("float32", "float64")
89-
_dtype_names = ("bool",) + _uint_names + _int_names + _float_names
94+
_real_names = _uint_names + _int_names + _float_names
95+
_complex_names = ("complex64", "complex128")
96+
_numeric_names = _real_names + _complex_names
97+
_dtype_names = ("bool",) + _numeric_names
9098

9199

92100
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
93101
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
94102
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
95103
all_int_dtypes = uint_dtypes + int_dtypes
96-
numeric_dtypes = all_int_dtypes + float_dtypes
104+
real_dtypes = all_int_dtypes + float_dtypes
105+
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
106+
numeric_dtypes = real_dtypes
107+
if api_version > "2021.12":
108+
numeric_dtypes += complex_dtypes
97109
all_dtypes = (xp.bool,) + numeric_dtypes
110+
all_float_dtypes = float_dtypes
111+
if api_version > "2021.12":
112+
all_float_dtypes += complex_dtypes
98113
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
99114

100115

@@ -121,14 +136,19 @@ def is_float_dtype(dtype):
121136
# See https://github.com/numpy/numpy/issues/18434
122137
if dtype is None:
123138
return False
124-
return dtype in float_dtypes
139+
valid_dtypes = float_dtypes
140+
if api_version > "2021.12":
141+
valid_dtypes += complex_dtypes
142+
return dtype in valid_dtypes
125143

126144

127145
def get_scalar_type(dtype: DataType) -> ScalarType:
128146
if is_int_dtype(dtype):
129147
return int
130148
elif is_float_dtype(dtype):
131149
return float
150+
elif dtype in complex_dtypes:
151+
return complex
132152
else:
133153
return bool
134154

@@ -157,7 +177,8 @@ class MinMax(NamedTuple):
157177
[(d, 8) for d in [xp.int8, xp.uint8]]
158178
+ [(d, 16) for d in [xp.int16, xp.uint16]]
159179
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
160-
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
180+
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]]
181+
+ [(xp.complex128, 128)]
161182
)
162183

163184

@@ -166,6 +187,11 @@ class MinMax(NamedTuple):
166187
)
167188

168189

190+
dtype_components = EqualityMapping(
191+
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
192+
)
193+
194+
169195
if isinstance(xp.asarray, _UndefinedStub):
170196
default_int = xp.int32
171197
default_float = xp.float32
@@ -180,6 +206,15 @@ class MinMax(NamedTuple):
180206
default_float = xp.asarray(float()).dtype
181207
if default_float not in float_dtypes:
182208
warn(f"inferred default float is {default_float!r}, which is not a float")
209+
if api_version > "2021.12":
210+
default_complex = xp.asarray(complex()).dtype
211+
if default_complex not in complex_dtypes:
212+
warn(
213+
f"inferred default complex is {default_complex!r}, "
214+
"which is not a complex"
215+
)
216+
else:
217+
default_complex = None
183218
if dtype_nbits[default_int] == 32:
184219
default_uint = xp.uint32
185220
else:
@@ -226,6 +261,11 @@ class MinMax(NamedTuple):
226261
((xp.float32, xp.float32), xp.float32),
227262
((xp.float32, xp.float64), xp.float64),
228263
((xp.float64, xp.float64), xp.float64),
264+
# complex
265+
((xp.complex64, xp.complex64), xp.complex64),
266+
((xp.complex64, xp.complex128), xp.complex128),
267+
((xp.complex128, xp.complex128), xp.complex128),
268+
229269
]
230270
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
231271
_promotion_table = list(set(_numeric_promotions))

array_api_tests/hypothesis_helpers.py

+50-17
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from operator import mul
55
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
66

7-
from hypothesis import assume
7+
from hypothesis import assume, reject
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
99
integers, just, lists, none, one_of,
1010
sampled_from, shared)
@@ -26,27 +26,20 @@
2626
# work for floating point dtypes as those are assumed to be defined in other
2727
# places in the tests.
2828
FILTER_UNDEFINED_DTYPES = True
29+
# TODO: currently we assume this to be true - we probably can remove this completely
30+
assert FILTER_UNDEFINED_DTYPES
2931

30-
integer_dtypes = sampled_from(dh.all_int_dtypes)
31-
floating_dtypes = sampled_from(dh.float_dtypes)
32-
numeric_dtypes = sampled_from(dh.numeric_dtypes)
33-
integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes)
34-
boolean_dtypes = just(xp.bool)
35-
dtypes = sampled_from(dh.all_dtypes)
36-
37-
if FILTER_UNDEFINED_DTYPES:
38-
integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
39-
floating_dtypes = floating_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
40-
numeric_dtypes = numeric_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
41-
integer_or_boolean_dtypes = integer_or_boolean_dtypes.filter(lambda x: not
42-
isinstance(x, _UndefinedStub))
43-
boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
44-
dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
32+
integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes()
33+
floating_dtypes = xps.floating_dtypes()
34+
numeric_dtypes = xps.numeric_dtypes()
35+
integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes
36+
boolean_dtypes = xps.boolean_dtypes()
37+
dtypes = xps.scalar_dtypes()
4538

4639
shared_dtypes = shared(dtypes, key="dtype")
4740
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
4841

49-
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
42+
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes]
5043
_sorted_dtypes = [d for category in _dtype_categories for d in category]
5144

5245
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
@@ -106,6 +99,46 @@ def mutually_promotable_dtypes(
10699
return one_of(strats).map(tuple)
107100

108101

102+
class OnewayPromotableDtypes(NamedTuple):
103+
input_dtype: DataType
104+
result_dtype: DataType
105+
106+
107+
@composite
108+
def oneway_promotable_dtypes(
109+
draw, dtypes: Sequence[DataType]
110+
) -> SearchStrategy[OnewayPromotableDtypes]:
111+
"""Return a strategy for input dtypes that promote to result dtypes."""
112+
d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes))
113+
result_dtype = dh.result_type(d1, d2)
114+
if d1 == result_dtype:
115+
return OnewayPromotableDtypes(d2, d1)
116+
elif d2 == result_dtype:
117+
return OnewayPromotableDtypes(d1, d2)
118+
else:
119+
reject()
120+
121+
122+
class OnewayBroadcastableShapes(NamedTuple):
123+
input_shape: Shape
124+
result_shape: Shape
125+
126+
127+
@composite
128+
def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]:
129+
"""Return a strategy for input shapes that broadcast to result shapes."""
130+
result_shape = draw(shapes(min_side=1))
131+
input_shape = draw(
132+
xps.broadcastable_shapes(
133+
result_shape,
134+
# Override defaults so bad shapes are less likely to be generated.
135+
max_side=None if result_shape == () else max(result_shape),
136+
max_dims=len(result_shape),
137+
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
138+
)
139+
return OnewayBroadcastableShapes(input_shape, result_shape)
140+
141+
109142
# shared() allows us to draw either the function or the function name and they
110143
# will both correspond to the same function.
111144

array_api_tests/meta/test_utils.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@
44

55
from .. import _array_module as xp
66
from .. import dtype_helpers as dh
7+
from .. import hypothesis_helpers as hh
78
from .. import shape_helpers as sh
89
from .. import xps
910
from ..test_creation_functions import frange
1011
from ..test_manipulation_functions import roll_ndindex
11-
from ..test_operators_and_elementwise_functions import (
12-
mock_int_dtype,
13-
oneway_broadcastable_shapes,
14-
oneway_promotable_dtypes,
15-
)
12+
from ..test_operators_and_elementwise_functions import mock_int_dtype
1613

1714

1815
@pytest.mark.parametrize(
@@ -115,11 +112,11 @@ def test_int_to_dtype(x, dtype):
115112
assert mock_int_dtype(x, dtype) == d
116113

117114

118-
@given(oneway_promotable_dtypes(dh.all_dtypes))
115+
@given(hh.oneway_promotable_dtypes(dh.all_dtypes))
119116
def test_oneway_promotable_dtypes(D):
120117
assert D.result_dtype == dh.result_type(*D)
121118

122119

123-
@given(oneway_broadcastable_shapes())
120+
@given(hh.oneway_broadcastable_shapes())
124121
def test_oneway_broadcastable_shapes(S):
125122
assert S.result_shape == sh.broadcast_shapes(*S)

0 commit comments

Comments
 (0)