Skip to content

Commit ba6995e

Browse files
committed
Bound test_full fill values to default dtypes
1 parent d0f41dc commit ba6995e

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

array_api_tests/dtype_helpers.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from warnings import warn
12
from typing import NamedTuple
23

34
from . import _array_module as xp
5+
from ._array_module import _UndefinedStub
46

57

68
__all__ = [
@@ -16,6 +18,8 @@
1618
'is_int_dtype',
1719
'is_float_dtype',
1820
'dtype_ranges',
21+
'default_int',
22+
'default_float',
1923
'promotion_table',
2024
'dtype_nbits',
2125
'dtype_signed',
@@ -84,6 +88,18 @@ class MinMax(NamedTuple):
8488
}
8589

8690

91+
if isinstance(xp.asarray, _UndefinedStub):
92+
default_int = xp.int32
93+
default_float = xp.float32
94+
warn(
95+
'array module does not have attribute asarray. '
96+
'default int is assumed int32, default float is assumed float32'
97+
)
98+
else:
99+
default_int = xp.asarray(int()).dtype
100+
default_float = xp.asarray(float()).dtype
101+
102+
87103
_numeric_promotions = {
88104
# ints
89105
(xp.int8, xp.int8): xp.int8,

array_api_tests/test_creation_functions.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import xps
1313

1414
from hypothesis import assume, given
15-
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite
15+
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite, SearchStrategy
1616

1717

1818
int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)
@@ -128,10 +128,20 @@ def test_eye(n_rows, n_cols, k, dtype):
128128
assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal"
129129

130130

131+
default_unsafe_dtypes = [xp.uint64]
132+
if dh.default_int == xp.int32:
133+
default_unsafe_dtypes.extend([xp.uint32, xp.int64])
134+
if dh.default_float == xp.float32:
135+
default_unsafe_dtypes.append(xp.float64)
136+
default_safe_scalar_dtypes: SearchStrategy = xps.scalar_dtypes().filter(
137+
lambda d: d not in default_unsafe_dtypes
138+
)
139+
140+
131141
@composite
132142
def full_fill_values(draw):
133143
kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw"))
134-
dtype = kw.get("dtype", None) or draw(xps.scalar_dtypes())
144+
dtype = kw.get("dtype", None) or draw(default_safe_scalar_dtypes)
135145
return draw(xps.from_dtype(dtype))
136146

137147

0 commit comments

Comments
 (0)