16
16
pytestmark = pytest .mark .ci
17
17
18
18
19
+ # TODO: test with complex dtypes
20
+ def non_complex_dtypes ():
21
+ return xps .boolean_dtypes () | xps .real_dtypes ()
22
+
23
+
19
24
def float32 (n : Union [int , float ]) -> float :
20
25
return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
21
26
22
27
23
28
@given (
24
- x_dtype = xps . scalar_dtypes (),
25
- dtype = xps . scalar_dtypes (),
29
+ x_dtype = non_complex_dtypes (),
30
+ dtype = non_complex_dtypes (),
26
31
kw = hh .kwargs (copy = st .booleans ()),
27
32
data = st .data (),
28
33
)
@@ -101,7 +106,7 @@ def test_broadcast_to(x, data):
101
106
# TODO: test values
102
107
103
108
104
- @given (_from = xps . scalar_dtypes (), to = xps . scalar_dtypes (), data = st .data ())
109
+ @given (_from = non_complex_dtypes (), to = non_complex_dtypes (), data = st .data ())
105
110
def test_can_cast (_from , to , data ):
106
111
from_ = data .draw (
107
112
st .just (_from ) | xps .arrays (dtype = _from , shape = hh .shapes ()), label = "from_"
@@ -114,10 +119,12 @@ def test_can_cast(_from, to, data):
114
119
if _from == xp .bool :
115
120
expected = to == xp .bool
116
121
else :
117
- for dtypes in [dh .all_int_dtypes , dh .float_dtypes ]:
122
+ same_family = None
123
+ for dtypes in [dh .all_int_dtypes , dh .float_dtypes , dh .complex_dtypes ]:
118
124
if _from in dtypes :
119
125
same_family = to in dtypes
120
126
break
127
+ assert same_family is not None # sanity check
121
128
if same_family :
122
129
from_min , from_max = dh .dtype_ranges [_from ]
123
130
to_min , to_max = dh .dtype_ranges [to ]
0 commit comments