Skip to content

Commit 05ab7f1

Browse files
authored
Merge pull request #266 from asmeurer/skip_dtypes
Support skipping dtypes by setting ARRAY_API_TESTS_SKIP_DTYPES
2 parents f9022a1 + 9816630 commit 05ab7f1

19 files changed

+202
-189
lines changed

.flake8

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[flake8]
2+
select = F

README.md

+14-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This is the test suite for array libraries adopting the [Python Array API
44
standard](https://data-apis.org/array-api/latest).
55

6-
Keeping full coverage of the spec is an on-going priority as the Array API evolves.
6+
Keeping full coverage of the spec is an on-going priority as the Array API evolves.
77
Feedback and contributions are welcome!
88

99
## Quickstart
@@ -285,6 +285,19 @@ values should result in more rigorous runs. For example, `--max-examples
285285
10_000` may find bugs where default runs don't but will take much longer to
286286
run.
287287

288+
#### Skipping Dtypes
289+
290+
The test suite will automatically skip testing of inessential dtypes if they
291+
are not present on the array module namespace, but dtypes can also be skipped
292+
manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
293+
a comma separated list of dtypes to skip. For example
294+
295+
```
296+
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
297+
```
298+
299+
Note that skipping certain essential dtypes such as `bool` and the default
300+
floating-point dtype is not supported.
288301
289302
## Contributing
290303

array_api_tests/dtype_helpers.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import re
23
from collections import defaultdict
34
from collections.abc import Mapping
@@ -104,9 +105,18 @@ def __repr__(self):
104105
numeric_names = real_names + complex_names
105106
dtype_names = ("bool",) + numeric_names
106107

108+
_skip_dtypes = os.getenv("ARRAY_API_TESTS_SKIP_DTYPES", '')
109+
_skip_dtypes = _skip_dtypes.split(',')
110+
skip_dtypes = []
111+
for dtype in _skip_dtypes:
112+
if dtype and dtype not in dtype_names:
113+
raise ValueError(f"Invalid dtype name in ARRAY_API_TESTS_SKIP_DTYPES: {dtype}")
114+
skip_dtypes.append(dtype)
107115

108116
_name_to_dtype = {}
109117
for name in dtype_names:
118+
if name in skip_dtypes:
119+
continue
110120
try:
111121
dtype = getattr(xp, name)
112122
except AttributeError:
@@ -184,9 +194,9 @@ def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
184194
dtype_value_pairs = []
185195
for name, value in mapping.items():
186196
assert isinstance(name, str) and name in dtype_names # sanity check
187-
try:
188-
dtype = getattr(xp, name)
189-
except AttributeError:
197+
if name in _name_to_dtype:
198+
dtype = _name_to_dtype[name]
199+
else:
190200
continue
191201
dtype_value_pairs.append((dtype, value))
192202
return EqualityMapping(dtype_value_pairs)
@@ -313,9 +323,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
313323
else:
314324
default_complex = None
315325
if dtype_nbits[default_int] == 32:
316-
default_uint = getattr(xp, "uint32", None)
326+
default_uint = _name_to_dtype.get("uint32")
317327
else:
318-
default_uint = getattr(xp, "uint64", None)
328+
default_uint = _name_to_dtype.get("uint64")
319329

320330
_promotion_table: Dict[Tuple[str, str], str] = {
321331
("bool", "bool"): "bool",
@@ -366,18 +376,12 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
366376
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
367377
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
368378
for (in_name1, in_name2), res_name in _promotion_table.items():
369-
try:
370-
in_dtype1 = getattr(xp, in_name1)
371-
except AttributeError:
372-
continue
373-
try:
374-
in_dtype2 = getattr(xp, in_name2)
375-
except AttributeError:
376-
continue
377-
try:
378-
res_dtype = getattr(xp, res_name)
379-
except AttributeError:
379+
if in_name1 not in _name_to_dtype or in_name2 not in _name_to_dtype or res_name not in _name_to_dtype:
380380
continue
381+
in_dtype1 = _name_to_dtype[in_name1]
382+
in_dtype2 = _name_to_dtype[in_name2]
383+
res_dtype = _name_to_dtype[res_name]
384+
381385
_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
382386
promotion_table = EqualityMapping(_promotion_table_pairs)
383387

array_api_tests/hypothesis_helpers.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,24 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
174174
return OnewayBroadcastableShapes(input_shape, result_shape)
175175

176176

177+
# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
178+
# ARRAY_API_TESTS_SKIP_DTYPES
179+
all_dtypes = sampled_from(_sorted_dtypes)
180+
int_dtypes = sampled_from(dh.int_dtypes)
181+
uint_dtypes = sampled_from(dh.uint_dtypes)
182+
real_dtypes = sampled_from(dh.real_dtypes)
183+
# Warning: The hypothesis "floating_dtypes" is what we call
184+
# "real_floating_dtypes"
185+
floating_dtypes = sampled_from(dh.all_float_dtypes)
186+
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
187+
numeric_dtypes = sampled_from(dh.numeric_dtypes)
188+
# Note: this always returns complex dtypes, even if api_version < 2022.12
189+
complex_dtypes = sampled_from(dh.complex_dtypes)
190+
177191
def all_floating_dtypes() -> SearchStrategy[DataType]:
178-
strat = xps.floating_dtypes()
192+
strat = floating_dtypes
179193
if api_version >= "2022.12":
180-
strat |= xps.complex_dtypes()
194+
strat |= complex_dtypes
181195
return strat
182196

183197

@@ -236,7 +250,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
236250

237251
@composite
238252
def finite_matrices(draw, shape=matrix_shapes()):
239-
return draw(arrays(dtype=xps.floating_dtypes(),
253+
return draw(arrays(dtype=floating_dtypes,
240254
shape=shape,
241255
elements=dict(allow_nan=False,
242256
allow_infinity=False)))
@@ -245,7 +259,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
245259
# Should we set a max_value here?
246260
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
247261
rtols = one_of(floats(**_rtol_float_kw),
248-
arrays(dtype=xps.floating_dtypes(),
262+
arrays(dtype=real_floating_dtypes,
249263
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
250264
elements=_rtol_float_kw))
251265

@@ -280,9 +294,9 @@ def mutually_broadcastable_shapes(
280294

281295
two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)
282296

283-
# Note: This should become hermitian_matrices when complex dtypes are added
297+
# TODO: Add support for complex Hermitian matrices
284298
@composite
285-
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
299+
def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.):
286300
shape = draw(square_matrix_shapes)
287301
dtype = draw(dtypes)
288302
if not isinstance(finite, bool):
@@ -297,7 +311,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10
297311
return H
298312

299313
@composite
300-
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
314+
def positive_definite_matrices(draw, dtypes=floating_dtypes):
301315
# For now just generate stacks of identity matrices
302316
# TODO: Generate arbitrary positive definite matrices, for instance, by
303317
# using something like
@@ -310,7 +324,7 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
310324
return broadcast_to(eye(n, dtype=dtype), shape)
311325

312326
@composite
313-
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
327+
def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):
314328
# For now, just generate stacks of diagonal matrices.
315329
stack_shape = draw(stack_shapes)
316330
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),)
@@ -344,7 +358,7 @@ def two_broadcastable_shapes(draw):
344358
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
345359

346360
numeric_arrays = arrays(
347-
dtype=shared(xps.floating_dtypes(), key='dtypes'),
361+
dtype=shared(floating_dtypes, key='dtypes'),
348362
shape=shared(xps.array_shapes(), key='shapes'),
349363
)
350364

@@ -388,7 +402,7 @@ def python_integer_indices(draw, sizes):
388402
def integer_indices(draw, sizes):
389403
# Return either a Python integer or a 0-D array with some integer dtype
390404
idx = draw(python_integer_indices(sizes))
391-
dtype = draw(xps.integer_dtypes() | xps.unsigned_integer_dtypes())
405+
dtype = draw(int_dtypes | uint_dtypes)
392406
m, M = dh.dtype_ranges[dtype]
393407
if m <= idx <= M:
394408
return draw(one_of(just(idx),

array_api_tests/pytest_helpers.py

+28
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,34 @@ def assert_dtype(
137137
assert out_dtype == expected, msg
138138

139139

140+
def assert_float_to_complex_dtype(
141+
func_name: str, *, in_dtype: DataType, out_dtype: DataType
142+
):
143+
if in_dtype == xp.float32:
144+
expected = xp.complex64
145+
else:
146+
assert in_dtype == xp.float64 # sanity check
147+
expected = xp.complex128
148+
assert_dtype(
149+
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
150+
)
151+
152+
153+
def assert_complex_to_float_dtype(
154+
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
155+
):
156+
if in_dtype == xp.complex64:
157+
expected = xp.float32
158+
elif in_dtype == xp.complex128:
159+
expected = xp.float64
160+
else:
161+
assert in_dtype in (xp.float32, xp.float64) # sanity check
162+
expected = in_dtype
163+
assert_dtype(
164+
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
165+
)
166+
167+
140168
def assert_kw_dtype(
141169
func_name: str,
142170
*,

array_api_tests/test_array_object.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from . import pytest_helpers as ph
1414
from . import shape_helpers as sh
1515
from . import xps
16-
from . import xp as _xp
1716
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1817

1918

@@ -75,7 +74,7 @@ def get_indexed_axes_and_out_shape(
7574
return tuple(axes_indices), tuple(out_shape)
7675

7776

78-
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
77+
@given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data())
7978
def test_getitem(shape, dtype, data):
8079
zero_sided = any(side == 0 for side in shape)
8180
if zero_sided:
@@ -157,7 +156,7 @@ def test_setitem(shape, dtypes, data):
157156
@pytest.mark.data_dependent_shapes
158157
@given(hh.shapes(), st.data())
159158
def test_getitem_masking(shape, data):
160-
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
159+
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
161160
mask_shapes = st.one_of(
162161
st.sampled_from([x.shape, ()]),
163162
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
@@ -202,7 +201,7 @@ def test_getitem_masking(shape, data):
202201
@pytest.mark.unvectorized
203202
@given(hh.shapes(), st.data())
204203
def test_setitem_masking(shape, data):
205-
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
204+
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
206205
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
207206
value = data.draw(
208207
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
@@ -252,18 +251,14 @@ def make_scalar_casting_param(
252251

253252

254253
@pytest.mark.parametrize(
255-
"method_name, dtype_name, stype",
256-
[make_scalar_casting_param("__bool__", "bool", bool)]
257-
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
258-
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
259-
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names],
254+
"method_name, dtype, stype",
255+
[make_scalar_casting_param("__bool__", xp.bool, bool)]
256+
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes]
257+
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes]
258+
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes],
260259
)
261260
@given(data=st.data())
262-
def test_scalar_casting(method_name, dtype_name, stype, data):
263-
try:
264-
dtype = getattr(_xp, dtype_name)
265-
except AttributeError as e:
266-
pytest.skip(str(e))
261+
def test_scalar_casting(method_name, dtype, stype, data):
267262
x = data.draw(hh.arrays(dtype, shape=()), label="x")
268263
method = getattr(x, method_name)
269264
out = method()

0 commit comments

Comments
 (0)