@@ -39,7 +39,7 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
39
39
return xps .boolean_dtypes () | all_integer_dtypes ()
40
40
41
41
42
- def isclose (n1 : float , n2 : float ):
42
+ def isclose (n1 : Union [ int , float ] , n2 : Union [ int , float ] ):
43
43
if not (math .isfinite (n1 ) and math .isfinite (n2 )):
44
44
raise ValueError (f"{ n1 = } and { n1 = } , but input must be finite" )
45
45
return math .isclose (n1 , n2 , rel_tol = 0.25 , abs_tol = 1 )
@@ -1394,20 +1394,45 @@ def test_remainder(ctx, data):
1394
1394
left = data .draw (ctx .left_strat , label = ctx .left_sym )
1395
1395
right = data .draw (ctx .right_strat , label = ctx .right_sym )
1396
1396
if ctx .right_is_scalar :
1397
- out_dtype = left . dtype
1397
+ assume ( right != 0 )
1398
1398
else :
1399
- out_dtype = dh .result_type (left .dtype , right .dtype )
1400
- if dh .is_int_dtype (out_dtype ):
1401
- if ctx .right_is_scalar :
1402
- assume (right != 0 )
1403
- else :
1404
- assume (not ah .any (right == 0 ))
1399
+ assume (not ah .any (right == 0 ))
1405
1400
1406
1401
res = ctx .func (left , right )
1407
1402
1408
1403
assert_binary_param_dtype (ctx , left , right , res )
1409
1404
assert_binary_param_shape (ctx , left , right , res )
1410
- # TODO: test results
1405
+ scalar_type = dh .get_scalar_type (res .dtype )
1406
+ if ctx .right_is_scalar :
1407
+ for idx in sh .ndindex (res .shape ):
1408
+ scalar_l = scalar_type (left [idx ])
1409
+ expected = scalar_l % right
1410
+ scalar_o = scalar_type (res [idx ])
1411
+ if not all (math .isfinite (n ) for n in [scalar_l , right , scalar_o , expected ]):
1412
+ continue
1413
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
1414
+ f_o = sh .fmt_idx (ctx .res_name , idx )
1415
+ assert isclose (scalar_o , expected ), (
1416
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } % { right } )={ expected } "
1417
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
1418
+ )
1419
+ else :
1420
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
1421
+ scalar_l = scalar_type (left [l_idx ])
1422
+ scalar_r = scalar_type (right [r_idx ])
1423
+ expected = scalar_l % scalar_r
1424
+ scalar_o = scalar_type (res [o_idx ])
1425
+ if not all (
1426
+ math .isfinite (n ) for n in [scalar_l , scalar_r , scalar_o , expected ]
1427
+ ):
1428
+ continue
1429
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
1430
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
1431
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
1432
+ assert isclose (scalar_o , expected ), (
1433
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } % { f_r } )={ expected } "
1434
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1435
+ )
1411
1436
1412
1437
1413
1438
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
0 commit comments