Skip to content

Commit 9709161

Browse files
committed
Fixed test_full_like
1 parent 8b11476 commit 9709161

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

array_api_tests/hypothesis_helpers.py

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
shared, floats, just, composite, one_of,
77
none, booleans)
88
from hypothesis.extra.array_api import make_strategies_namespace
9-
from hypothesis import assume
109

1110
from .pytest_helpers import nargs
1211
from .array_helpers import (dtype_ranges, integer_dtype_objects,

array_api_tests/test_creation_functions.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
scalars, xps, kwargs)
99

1010
from hypothesis import assume, given
11-
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared
11+
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite
1212

1313

1414
optional_dtypes = none() | shared_dtypes
@@ -148,10 +148,17 @@ def test_full(shape, fill_value, dtype):
148148
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
149149

150150

151+
@composite
152+
def fill_values(draw):
153+
kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"))
154+
dtype = kw.get("dtype", None) or draw(shared_dtypes)
155+
return draw(xps.from_dtype(dtype))
156+
157+
151158
@given(
152159
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
153-
fill_value=shared_dtypes.flatmap(xps.from_dtype),
154-
kw=kwargs(dtype=none() | shared_dtypes),
160+
fill_value=fill_values(),
161+
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"),
155162
)
156163
def test_full_like(x, fill_value, kw):
157164
out = full_like(x, fill_value, **kw)

0 commit comments

Comments
 (0)