Skip to content

Filter undefined dtypes in hh.mutually_promotable_dtypes() #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from hypothesis.extra.array_api import make_strategies_namespace

from . import _array_module as xp
from ._array_module import mod as _xp


xps = make_strategies_namespace(xp)
xps = make_strategies_namespace(_xp)


del _xp
del make_strategies_namespace
17 changes: 6 additions & 11 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import reduce
from math import sqrt
from operator import mul
from typing import Any, List, NamedTuple, Optional, Tuple
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence

from hypothesis import assume
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
Expand Down Expand Up @@ -68,19 +68,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):

promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)

if FILTER_UNDEFINED_DTYPES:
promotable_dtypes = [
(i, j) for i, j in promotable_dtypes
if not isinstance(i, _UndefinedStub)
and not isinstance(j, _UndefinedStub)
]


def mutually_promotable_dtypes(
max_size: Optional[int] = 2,
*,
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
dtypes: Sequence[DataType] = dh.all_dtypes,
) -> SearchStrategy[Tuple[DataType, ...]]:
if FILTER_UNDEFINED_DTYPES:
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
if max_size == 2:
return sampled_from(
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
Expand Down Expand Up @@ -347,7 +342,7 @@ def multiaxis_indices(draw, shapes):


def two_mutual_arrays(
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
dtypes: Sequence[DataType] = dh.all_dtypes,
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
) -> SearchStrategy:
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
Expand Down
18 changes: 16 additions & 2 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]

@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
def test_mutually_promotable_dtypes(pairs):
assert pairs in (
def test_mutually_promotable_dtypes(pair):
assert pair in (
(xp.float32, xp.float32),
(xp.float32, xp.float64),
(xp.float64, xp.float32),
(xp.float64, xp.float64),
)


@given(
hh.mutually_promotable_dtypes(
dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
)
)
def test_partial_mutually_promotable_dtypes(pair):
assert pair in (
(xp.uint8, xp.uint8),
(xp.uint8, xp.uint32),
(xp.uint32, xp.uint8),
(xp.uint32, xp.uint32),
)


def valid_shape(shape) -> bool:
return (
all(isinstance(side, int) for side in shape)
Expand Down
18 changes: 18 additions & 0 deletions array_api_tests/meta/test_partial_adopters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from hypothesis import given

from .. import dtype_helpers as dh
from .. import hypothesis_helpers as hh
from .. import _array_module as xp
from .._array_module import _UndefinedStub


# e.g. PyTorch only supports uint8 currently
@pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined")
@pytest.mark.skipif(
not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]),
reason="uints defined",
)
@given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes))
def test_mutually_promotable_dtypes(pair):
assert pair == (xp.uint8, xp.uint8)