Skip to content

Commit 58af424

Browse files
authored
Merge pull request #34 from honno/partial-dtypes
Filter undefined dtypes in `hh.mutually_promotable_dtypes()`
2 parents 0574111 + 0938424 commit 58af424

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

array_api_tests/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from hypothesis.extra.array_api import make_strategies_namespace
22

3-
from . import _array_module as xp
3+
from ._array_module import mod as _xp
44

55

6-
xps = make_strategies_namespace(xp)
6+
xps = make_strategies_namespace(_xp)
7+
8+
9+
del _xp
10+
del make_strategies_namespace

array_api_tests/hypothesis_helpers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import reduce
33
from math import sqrt
44
from operator import mul
5-
from typing import Any, List, NamedTuple, Optional, Tuple
5+
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence
66

77
from hypothesis import assume
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
@@ -68,19 +68,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6868

6969
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
7070

71-
if FILTER_UNDEFINED_DTYPES:
72-
promotable_dtypes = [
73-
(i, j) for i, j in promotable_dtypes
74-
if not isinstance(i, _UndefinedStub)
75-
and not isinstance(j, _UndefinedStub)
76-
]
77-
78-
7971
def mutually_promotable_dtypes(
8072
max_size: Optional[int] = 2,
8173
*,
82-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
74+
dtypes: Sequence[DataType] = dh.all_dtypes,
8375
) -> SearchStrategy[Tuple[DataType, ...]]:
76+
if FILTER_UNDEFINED_DTYPES:
77+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
78+
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
8479
if max_size == 2:
8580
return sampled_from(
8681
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -347,7 +342,7 @@ def multiaxis_indices(draw, shapes):
347342

348343

349344
def two_mutual_arrays(
350-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
345+
dtypes: Sequence[DataType] = dh.all_dtypes,
351346
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
352347
) -> SearchStrategy:
353348
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,29 @@
1515
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1616

1717
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
18-
def test_mutually_promotable_dtypes(pairs):
19-
assert pairs in (
18+
def test_mutually_promotable_dtypes(pair):
19+
assert pair in (
2020
(xp.float32, xp.float32),
2121
(xp.float32, xp.float64),
2222
(xp.float64, xp.float32),
2323
(xp.float64, xp.float64),
2424
)
2525

2626

27+
@given(
28+
hh.mutually_promotable_dtypes(
29+
dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
30+
)
31+
)
32+
def test_partial_mutually_promotable_dtypes(pair):
33+
assert pair in (
34+
(xp.uint8, xp.uint8),
35+
(xp.uint8, xp.uint32),
36+
(xp.uint32, xp.uint8),
37+
(xp.uint32, xp.uint32),
38+
)
39+
40+
2741
def valid_shape(shape) -> bool:
2842
return (
2943
all(isinstance(side, int) for side in shape)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from hypothesis import given
3+
4+
from .. import dtype_helpers as dh
5+
from .. import hypothesis_helpers as hh
6+
from .. import _array_module as xp
7+
from .._array_module import _UndefinedStub
8+
9+
10+
# e.g. PyTorch only supports uint8 currently
11+
@pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined")
12+
@pytest.mark.skipif(
13+
not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]),
14+
reason="uints defined",
15+
)
16+
@given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes))
17+
def test_mutually_promotable_dtypes(pair):
18+
assert pair == (xp.uint8, xp.uint8)

0 commit comments

Comments
 (0)