@@ -103,6 +103,8 @@ def default_filter(s: Scalar) -> bool:
103
103
"""
104
104
if isinstance (s , int ): # note bools are ints
105
105
return True
106
+ elif isinstance (s , complex ):
107
+ return default_filter (s .real ) and default_filter (s .imag )
106
108
else :
107
109
return math .isfinite (s ) and s != 0
108
110
@@ -247,7 +249,12 @@ def unary_assert_against_refimpl(
247
249
in_stype = dh .get_scalar_type (in_ .dtype )
248
250
if res_stype is None :
249
251
res_stype = in_stype
250
- m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
252
+ if res .dtype == xp .bool :
253
+ m , M = (None , None )
254
+ if res .dtype in dh .complex_dtypes :
255
+ m , M = dh .dtype_ranges [dh .dtype_components [res .dtype ]]
256
+ else :
257
+ m , M = dh .dtype_ranges [res .dtype ]
251
258
for idx in sh .ndindex (in_ .shape ):
252
259
scalar_i = in_stype (in_ [idx ])
253
260
if not filter_ (scalar_i ):
@@ -257,9 +264,13 @@ def unary_assert_against_refimpl(
257
264
except Exception :
258
265
continue
259
266
if res .dtype != xp .bool :
260
- assert m is not None and M is not None # for mypy
261
- if expected <= m or expected >= M :
262
- continue
267
+ if res .dtype in dh .complex_dtypes :
268
+ for component in [expected .real , expected .imag ]:
269
+ if component <= m or expected >= M :
270
+ continue
271
+ else :
272
+ if expected <= m or expected >= M :
273
+ continue
263
274
scalar_o = res_stype (res [idx ])
264
275
f_i = sh .fmt_idx ("x" , idx )
265
276
f_o = sh .fmt_idx ("out" , idx )
@@ -418,8 +429,11 @@ def __repr__(self):
418
429
419
430
420
431
def make_unary_params (
421
- elwise_func_name : str , dtypes_strat : st . SearchStrategy [DataType ]
432
+ elwise_func_name : str , dtypes : Sequence [DataType ]
422
433
) -> List [Param [UnaryParamContext ]]:
434
+ if hh .FILTER_UNDEFINED_DTYPES :
435
+ dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
436
+ dtypes_strat = st .sampled_from (dtypes )
423
437
strat = xps .arrays (dtype = dtypes_strat , shape = hh .shapes ())
424
438
func_ctx = UnaryParamContext (
425
439
func_name = elwise_func_name , func = getattr (xp , elwise_func_name ), strat = strat
@@ -633,7 +647,7 @@ def binary_param_assert_against_refimpl(
633
647
)
634
648
635
649
636
- @pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , xps .numeric_dtypes () ))
650
+ @pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
637
651
@given (data = st .data ())
638
652
def test_abs (ctx , data ):
639
653
x = data .draw (ctx .strat , label = "x" )
@@ -643,7 +657,10 @@ def test_abs(ctx, data):
643
657
644
658
out = ctx .func (x )
645
659
646
- ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
660
+ if x .dtype in dh .complex_dtypes :
661
+ assert out .dtype == dh .complex_components [x .dtype ]
662
+ else :
663
+ ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
647
664
ph .assert_shape (ctx .func_name , out .shape , x .shape )
648
665
unary_assert_against_refimpl (
649
666
ctx .func_name ,
@@ -783,7 +800,7 @@ def test_bitwise_left_shift(ctx, data):
783
800
784
801
785
802
@pytest .mark .parametrize (
786
- "ctx" , make_unary_params ("bitwise_invert" , boolean_and_all_integer_dtypes () )
803
+ "ctx" , make_unary_params ("bitwise_invert" , dh . bool_and_all_int_dtypes )
787
804
)
788
805
@given (data = st .data ())
789
806
def test_bitwise_invert (ctx , data ):
@@ -1187,9 +1204,7 @@ def test_multiply(ctx, data):
1187
1204
1188
1205
1189
1206
# TODO: clarify if uints are acceptable, adjust accordingly
1190
- @pytest .mark .parametrize (
1191
- "ctx" , make_unary_params ("negative" , xps .integer_dtypes () | xps .floating_dtypes ())
1192
- )
1207
+ @pytest .mark .parametrize ("ctx" , make_unary_params ("negative" , dh .numeric_dtypes ))
1193
1208
@given (data = st .data ())
1194
1209
def test_negative (ctx , data ):
1195
1210
x = data .draw (ctx .strat , label = "x" )
@@ -1226,7 +1241,7 @@ def test_not_equal(ctx, data):
1226
1241
)
1227
1242
1228
1243
1229
- @pytest .mark .parametrize ("ctx" , make_unary_params ("positive" , xps .numeric_dtypes () ))
1244
+ @pytest .mark .parametrize ("ctx" , make_unary_params ("positive" , dh .numeric_dtypes ))
1230
1245
@given (data = st .data ())
1231
1246
def test_positive (ctx , data ):
1232
1247
x = data .draw (ctx .strat , label = "x" )
@@ -1317,7 +1332,7 @@ def test_square(x):
1317
1332
ph .assert_dtype ("square" , x .dtype , out .dtype )
1318
1333
ph .assert_shape ("square" , out .shape , x .shape )
1319
1334
unary_assert_against_refimpl (
1320
- "square" , x , out , lambda s : s ** 2 , expr_template = "{}²={}"
1335
+ "square" , x , out , lambda s : s ** 2 , expr_template = "{}²={}"
1321
1336
)
1322
1337
1323
1338
0 commit comments