Skip to content

Commit 7f8a19d

Browse files
committed
Raise TypeError on bad dtypes in hh.two_mutual_arrays()
1 parent 9c9ffe1 commit 7f8a19d

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .array_helpers import ndindex
1919
from .function_stubs import elementwise_functions
2020
from .pytest_helpers import nargs
21-
from .typing import DataType, Shape
21+
from .typing import DataType, Shape, Array
2222

2323
# Set this to True to not fail tests just because a dtype isn't implemented.
2424
# If no compatible dtype is implemented for a given test, the test will fail
@@ -344,7 +344,9 @@ def multiaxis_indices(draw, shapes):
344344
def two_mutual_arrays(
345345
dtypes: Sequence[DataType] = dh.all_dtypes,
346346
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
347-
) -> SearchStrategy:
347+
) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
348+
if not isinstance(dtypes, Sequence):
349+
raise TypeError(f"{dtypes=} not a sequence")
348350
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
349351
mutual_shapes = shared(two_shapes)
350352
arrays1 = xps.arrays(

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def test_two_mutual_arrays(x1, x2):
7070
assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape)
7171

7272

73+
def test_two_mutual_arrays_raises_on_bad_dtypes():
74+
with pytest.raises(TypeError):
75+
hh.two_mutual_arrays(dtypes=xps.scalar_dtypes())
76+
77+
7378
def test_kwargs():
7479
results = []
7580

0 commit comments

Comments
 (0)