@@ -62,7 +62,8 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
62
62
#
63
63
# By default, floating-point functions/methods are loosely asserted against. Use
64
64
# `strict_check=True` when they should be strictly asserted against, i.e.
65
- # when a function should return intergrals.
65
+ # when a function should return intergrals. Likewise, use `strict_check=False`
66
+ # when integer function/methods should be loosely asserted against.
66
67
67
68
68
69
def isclose (a : float , b : float , rel_tol : float = 0.25 , abs_tol : float = 1 ) -> bool :
@@ -92,7 +93,7 @@ def unary_assert_against_refimpl(
92
93
expr_template : Optional [str ] = None ,
93
94
res_stype : Optional [ScalarType ] = None ,
94
95
filter_ : Callable [[Scalar ], bool ] = default_filter ,
95
- strict_check : bool = False ,
96
+ strict_check : Optional [ bool ] = None ,
96
97
):
97
98
if in_ .shape != res .shape :
98
99
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -108,7 +109,7 @@ def unary_assert_against_refimpl(
108
109
continue
109
110
try :
110
111
expected = refimpl (scalar_i )
111
- except OverflowError :
112
+ except Exception :
112
113
continue
113
114
if res .dtype != xp .bool :
114
115
assert m is not None and M is not None # for mypy
@@ -118,7 +119,7 @@ def unary_assert_against_refimpl(
118
119
f_i = sh .fmt_idx ("x" , idx )
119
120
f_o = sh .fmt_idx ("out" , idx )
120
121
expr = expr_template .format (f_i , expected )
121
- if not strict_check and dh .is_float_dtype (res .dtype ):
122
+ if strict_check == False or dh .is_float_dtype (res .dtype ):
122
123
assert isclose (scalar_o , expected ), (
123
124
f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
124
125
f"{ f_i } ={ scalar_i } "
@@ -142,7 +143,7 @@ def binary_assert_against_refimpl(
142
143
right_sym : str = "x2" ,
143
144
res_name : str = "out" ,
144
145
filter_ : Callable [[Scalar ], bool ] = default_filter ,
145
- strict_check : bool = False ,
146
+ strict_check : Optional [ bool ] = None ,
146
147
):
147
148
if expr_template is None :
148
149
expr_template = func_name + "({}, {})={}"
@@ -157,7 +158,7 @@ def binary_assert_against_refimpl(
157
158
continue
158
159
try :
159
160
expected = refimpl (scalar_l , scalar_r )
160
- except OverflowError :
161
+ except Exception :
161
162
continue
162
163
if res .dtype != xp .bool :
163
164
assert m is not None and M is not None # for mypy
@@ -168,7 +169,7 @@ def binary_assert_against_refimpl(
168
169
f_r = sh .fmt_idx (right_sym , r_idx )
169
170
f_o = sh .fmt_idx (res_name , o_idx )
170
171
expr = expr_template .format (f_l , f_r , expected )
171
- if not strict_check and dh .is_float_dtype (res .dtype ):
172
+ if strict_check == False or dh .is_float_dtype (res .dtype ):
172
173
assert isclose (scalar_o , expected ), (
173
174
f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
174
175
f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -384,11 +385,12 @@ def binary_param_assert_against_refimpl(
384
385
refimpl : Callable [[Scalar , Scalar ], Scalar ],
385
386
res_stype : Optional [ScalarType ] = None ,
386
387
filter_ : Callable [[Scalar ], bool ] = default_filter ,
387
- strict_check : bool = False ,
388
+ strict_check : Optional [ bool ] = None ,
388
389
):
389
390
expr_template = "({} " + op_sym + " {})={}"
390
391
if ctx .right_is_scalar :
391
- assert filter_ (right ) # sanity check
392
+ if filter_ (right ):
393
+ return # short-circuit here as there will be nothing to test
392
394
in_stype = dh .get_scalar_type (left .dtype )
393
395
if res_stype is None :
394
396
res_stype = in_stype
@@ -399,7 +401,7 @@ def binary_param_assert_against_refimpl(
399
401
continue
400
402
try :
401
403
expected = refimpl (scalar_l , right )
402
- except OverflowError :
404
+ except Exception :
403
405
continue
404
406
if left .dtype != xp .bool :
405
407
assert m is not None and M is not None # for mypy
@@ -409,7 +411,7 @@ def binary_param_assert_against_refimpl(
409
411
f_l = sh .fmt_idx (ctx .left_sym , idx )
410
412
f_o = sh .fmt_idx (ctx .res_name , idx )
411
413
expr = expr_template .format (f_l , right , expected )
412
- if not strict_check and dh .is_float_dtype (left .dtype ):
414
+ if strict_check == False or dh .is_float_dtype (res .dtype ):
413
415
assert isclose (scalar_o , expected ), (
414
416
f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
415
417
f"[{ ctx .func_name } ()]\n "
@@ -704,16 +706,22 @@ def test_cosh(x):
704
706
def test_divide (ctx , data ):
705
707
left = data .draw (ctx .left_strat , label = ctx .left_sym )
706
708
right = data .draw (ctx .right_strat , label = ctx .right_sym )
709
+ if ctx .right_is_scalar :
710
+ assume
707
711
708
712
res = ctx .func (left , right )
709
713
710
714
binary_param_assert_dtype (ctx , left , right , res )
711
715
binary_param_assert_shape (ctx , left , right , res )
712
- # There isn't much we can test here. The spec doesn't require any behavior
713
- # beyond the special cases, and indeed, there aren't many mathematical
714
- # properties of division that strictly hold for floating-point numbers. We
715
- # could test that this does implement IEEE 754 division, but we don't yet
716
- # have those sorts in general for this module.
716
+ binary_param_assert_against_refimpl (
717
+ ctx ,
718
+ left ,
719
+ right ,
720
+ res ,
721
+ "/" ,
722
+ operator .truediv ,
723
+ filter_ = lambda s : math .isfinite (s ) and s != 0 ,
724
+ )
717
725
718
726
719
727
@pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps .scalar_dtypes ()))
@@ -836,17 +844,7 @@ def test_isfinite(x):
836
844
out = ah .isfinite (x )
837
845
ph .assert_dtype ("isfinite" , x .dtype , out .dtype , xp .bool )
838
846
ph .assert_shape ("isfinite" , out .shape , x .shape )
839
- if dh .is_int_dtype (x .dtype ):
840
- ah .assert_exactly_equal (out , ah .true (x .shape ))
841
- # Test that isfinite, isinf, and isnan are self-consistent.
842
- inf = ah .logical_or (xp .isinf (x ), ah .isnan (x ))
843
- ah .assert_exactly_equal (out , ah .logical_not (inf ))
844
-
845
- # Test the exact value by comparing to the math version
846
- if dh .is_float_dtype (x .dtype ):
847
- for idx in sh .ndindex (x .shape ):
848
- s = float (x [idx ])
849
- assert bool (out [idx ]) == math .isfinite (s )
847
+ unary_assert_against_refimpl ("isfinite" , x , out , math .isfinite , res_stype = bool )
850
848
851
849
852
850
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -949,9 +947,10 @@ def test_log10(x):
949
947
def test_logaddexp (x1 , x2 ):
950
948
out = xp .logaddexp (x1 , x2 )
951
949
ph .assert_dtype ("logaddexp" , [x1 .dtype , x2 .dtype ], out .dtype )
952
- # The spec doesn't require any behavior for this function. We could test
953
- # that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
954
- # don't have tests for this sort of thing for any functions yet.
950
+ ph .assert_result_shape ("logaddexp" , [x1 .shape , x2 .shape ], out .shape )
951
+ binary_assert_against_refimpl (
952
+ "logaddexp" , x1 , x2 , out , lambda l , r : math .log (math .exp (l ) + math .exp (r ))
953
+ )
955
954
956
955
957
956
@given (* hh .two_mutual_arrays ([xp .bool ]))
@@ -1078,11 +1077,9 @@ def test_pow(ctx, data):
1078
1077
1079
1078
binary_param_assert_dtype (ctx , left , right , res )
1080
1079
binary_param_assert_shape (ctx , left , right , res )
1081
- # There isn't much we can test here. The spec doesn't require any behavior
1082
- # beyond the special cases, and indeed, there aren't many mathematical
1083
- # properties of exponentiation that strictly hold for floating-point
1084
- # numbers. We could test that this does implement IEEE 754 pow, but we
1085
- # don't yet have those sorts in general for this module.
1080
+ binary_param_assert_against_refimpl (
1081
+ ctx , left , right , res , "**" , math .pow , strict_check = False
1082
+ )
1086
1083
1087
1084
1088
1085
@pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes ()))
@@ -1110,28 +1107,14 @@ def test_round(x):
1110
1107
unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
1111
1108
1112
1109
1113
- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
1110
+ @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (), elements = finite_kw ))
1114
1111
def test_sign (x ):
1115
1112
out = xp .sign (x )
1116
1113
ph .assert_dtype ("sign" , x .dtype , out .dtype )
1117
1114
ph .assert_shape ("sign" , out .shape , x .shape )
1118
- scalar_type = dh .get_scalar_type (out .dtype )
1119
- for idx in sh .ndindex (x .shape ):
1120
- scalar_x = scalar_type (x [idx ])
1121
- f_x = sh .fmt_idx ("x" , idx )
1122
- if math .isnan (scalar_x ):
1123
- continue
1124
- if scalar_x == 0 :
1125
- expected = 0
1126
- expr = f"{ f_x } =0"
1127
- else :
1128
- expected = 1 if scalar_x > 0 else - 1
1129
- expr = f"({ f_x } / |{ f_x } |)={ expected } "
1130
- scalar_o = scalar_type (out [idx ])
1131
- f_o = sh .fmt_idx ("out" , idx )
1132
- assert (
1133
- scalar_o == expected
1134
- ), f"{ f_o } ={ scalar_o } , but should be { expr } [sign()]\n { f_x } ={ scalar_x } "
1115
+ unary_assert_against_refimpl (
1116
+ "sign" , x , out , lambda s : math .copysign (1 , s ), filter_ = lambda s : s != 0
1117
+ )
1135
1118
1136
1119
1137
1120
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
0 commit comments