Skip to content

Commit 91026c0

Browse files
authored
Merge pull request #82 from honno/dtype-funcs
Test dtype functions
2 parents cd941c9 + 65dc8c1 commit 91026c0

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

array_api_tests/dtype_helpers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
8686

8787

8888
class MinMax(NamedTuple):
89-
min: int
90-
max: int
89+
min: Union[int, float]
90+
max: Union[int, float]
9191

9292

9393
dtype_ranges = {
@@ -99,6 +99,8 @@ class MinMax(NamedTuple):
9999
xp.uint16: MinMax(0, +65_535),
100100
xp.uint32: MinMax(0, +4_294_967_295),
101101
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
102+
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
103+
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
102104
}
103105

104106
dtype_nbits = {

array_api_tests/test_data_type_functions.py

+119
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,131 @@
1+
import struct
2+
from typing import Union
3+
14
import pytest
25
from hypothesis import given
6+
from hypothesis import strategies as st
37

48
from . import _array_module as xp
59
from . import dtype_helpers as dh
610
from . import hypothesis_helpers as hh
711
from . import pytest_helpers as ph
12+
from . import xps
13+
from .algos import broadcast_shapes
814
from .typing import DataType
915

16+
pytestmark = pytest.mark.ci
17+
18+
19+
def float32(n: Union[int, float]) -> float:
20+
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
21+
22+
23+
@given(
24+
x_dtype=xps.scalar_dtypes(),
25+
dtype=xps.scalar_dtypes(),
26+
kw=hh.kwargs(copy=st.booleans()),
27+
data=st.data(),
28+
)
29+
def test_astype(x_dtype, dtype, kw, data):
30+
if xp.bool in (x_dtype, dtype):
31+
elements_strat = xps.from_dtype(x_dtype)
32+
else:
33+
m1, M1 = dh.dtype_ranges[x_dtype]
34+
m2, M2 = dh.dtype_ranges[dtype]
35+
if dh.is_int_dtype(x_dtype):
36+
cast = int
37+
elif x_dtype == xp.float32:
38+
cast = float32
39+
else:
40+
cast = float
41+
min_value = cast(max(m1, m2))
42+
max_value = cast(min(M1, M2))
43+
elements_strat = xps.from_dtype(
44+
x_dtype,
45+
min_value=min_value,
46+
max_value=max_value,
47+
allow_nan=False,
48+
allow_infinity=False,
49+
)
50+
x = data.draw(
51+
xps.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
52+
)
53+
54+
out = xp.astype(x, dtype, **kw)
55+
56+
ph.assert_kw_dtype("astype", dtype, out.dtype)
57+
ph.assert_shape("astype", out.shape, x.shape)
58+
# TODO: test values
59+
# TODO: test copy
60+
61+
62+
@given(
63+
shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes), data=st.data()
64+
)
65+
def test_broadcast_arrays(shapes, data):
66+
arrays = []
67+
for c, shape in enumerate(shapes, 1):
68+
x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}")
69+
arrays.append(x)
70+
71+
out = xp.broadcast_arrays(*arrays)
72+
73+
out_shape = broadcast_shapes(*shapes)
74+
for i, x in enumerate(arrays):
75+
ph.assert_dtype(
76+
"broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype"
77+
)
78+
ph.assert_result_shape(
79+
"broadcast_arrays",
80+
shapes,
81+
out[i].shape,
82+
out_shape,
83+
repr_name=f"out[{i}].shape",
84+
)
85+
# TODO: test values
86+
87+
88+
@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
89+
def test_broadcast_to(x, data):
90+
shape = data.draw(
91+
hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
92+
.map(lambda S: S[0])
93+
.filter(lambda s: broadcast_shapes(x.shape, s) == s),
94+
label="shape",
95+
)
96+
97+
out = xp.broadcast_to(x, shape)
98+
99+
ph.assert_dtype("broadcast_to", x.dtype, out.dtype)
100+
ph.assert_shape("broadcast_to", out.shape, shape)
101+
# TODO: test values
102+
103+
104+
@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data())
105+
def test_can_cast(_from, to, data):
106+
from_ = data.draw(
107+
st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_"
108+
)
109+
110+
out = xp.can_cast(from_, to)
111+
112+
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
113+
assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}"
114+
if _from == xp.bool:
115+
expected = to == xp.bool
116+
else:
117+
for dtypes in [dh.all_int_dtypes, dh.float_dtypes]:
118+
if _from in dtypes:
119+
same_family = to in dtypes
120+
break
121+
if same_family:
122+
from_min, from_max = dh.dtype_ranges[_from]
123+
to_min, to_max = dh.dtype_ranges[to]
124+
expected = from_min >= to_min and from_max <= to_max
125+
else:
126+
expected = False
127+
assert out == expected, f"{out=}, but should be {expected} {f_func}"
128+
10129

11130
def make_dtype_id(dtype: DataType) -> str:
12131
return dh.dtype_to_name[dtype]

0 commit comments

Comments
 (0)