Skip to content

Commit d114745

Browse files
committed
Test case for PyTorch using mutually_promotable_dtypes
1 parent e437c43 commit d114745

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from hypothesis import given
3+
4+
from .. import dtype_helpers as dh
5+
from .. import hypothesis_helpers as hh
6+
from .. import _array_module as xp
7+
from .._array_module import _UndefinedStub
8+
9+
10+
# e.g. PyTorch only supports uint8 currently
11+
@pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined")
12+
@pytest.mark.skipif(
13+
not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]),
14+
reason="uints defined",
15+
)
16+
@given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes))
17+
def test_mutually_promotable_dtypes(pair):
18+
assert pair == (xp.uint8, xp.uint8)

0 commit comments

Comments
 (0)