Skip to content

Commit 05f2cf9

Browse files
committed
Skip testing complex dtypes in test_data_type_functions.py for now
1 parent 90e7837 commit 05f2cf9

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616
pytestmark = pytest.mark.ci
1717

1818

19+
# TODO: test with complex dtypes
20+
def non_complex_dtypes():
21+
return xps.boolean_dtypes() | xps.real_dtypes()
22+
23+
1924
def float32(n: Union[int, float]) -> float:
2025
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2126

2227

2328
@given(
24-
x_dtype=xps.scalar_dtypes(),
25-
dtype=xps.scalar_dtypes(),
29+
x_dtype=non_complex_dtypes(),
30+
dtype=non_complex_dtypes(),
2631
kw=hh.kwargs(copy=st.booleans()),
2732
data=st.data(),
2833
)
@@ -101,7 +106,7 @@ def test_broadcast_to(x, data):
101106
# TODO: test values
102107

103108

104-
@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data())
109+
@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data())
105110
def test_can_cast(_from, to, data):
106111
from_ = data.draw(
107112
st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_"
@@ -114,10 +119,12 @@ def test_can_cast(_from, to, data):
114119
if _from == xp.bool:
115120
expected = to == xp.bool
116121
else:
117-
for dtypes in [dh.all_int_dtypes, dh.float_dtypes]:
122+
same_family = None
123+
for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]:
118124
if _from in dtypes:
119125
same_family = to in dtypes
120126
break
127+
assert same_family is not None # sanity check
121128
if same_family:
122129
from_min, from_max = dh.dtype_ranges[_from]
123130
to_min, to_max = dh.dtype_ranges[to]

0 commit comments

Comments
 (0)