@@ -39,6 +39,12 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
39
39
return xps .boolean_dtypes () | all_integer_dtypes ()
40
40
41
41
42
+ def isclose (n1 : float , n2 : float ):
43
+ if not (math .isfinite (n1 ) and math .isfinite (n2 )):
44
+ raise ValueError (f"{ n1 = } and { n1 = } , but input must be finite" )
45
+ return math .isclose (n1 , n2 , rel_tol = 0.25 , abs_tol = 1 )
46
+
47
+
42
48
# When appropiate, this module tests operators alongside their respective
43
49
# elementwise methods. We do this by parametrizing a generalised test method
44
50
# with every relevant method and operator.
@@ -766,6 +772,7 @@ def test_divide(ctx, data):
766
772
res = ctx .func (left , right )
767
773
768
774
assert_binary_param_dtype (ctx , left , right , res )
775
+ assert_binary_param_shape (ctx , left , right , res )
769
776
# There isn't much we can test here. The spec doesn't require any behavior
770
777
# beyond the special cases, and indeed, there aren't many mathematical
771
778
# properties of division that strictly hold for floating-point numbers. We
@@ -884,23 +891,38 @@ def test_floor_divide(ctx, data):
884
891
res = ctx .func (left , right )
885
892
886
893
assert_binary_param_dtype (ctx , left , right , res )
887
- if not ctx .right_is_scalar :
888
- if dh .is_int_dtype (left .dtype ):
889
- # The spec does not specify the behavior for division by 0 for integer
890
- # dtypes. A library may choose to raise an exception in this case, so
891
- # we avoid passing it in entirely.
892
- div = xp .divide (
893
- ah .asarray (left , dtype = xp .float64 ),
894
- ah .asarray (right , dtype = xp .float64 ),
894
+ assert_binary_param_shape (ctx , left , right , res )
895
+ scalar_type = dh .get_scalar_type (res .dtype )
896
+ if ctx .right_is_scalar :
897
+ for idx in sh .ndindex (res .shape ):
898
+ scalar_l = scalar_type (left [idx ])
899
+ expected = scalar_l // right
900
+ scalar_o = scalar_type (res [idx ])
901
+ if not all (math .isfinite (n ) for n in [scalar_l , right , scalar_o , expected ]):
902
+ continue
903
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
904
+ f_o = sh .fmt_idx (ctx .res_name , idx )
905
+ assert isclose (scalar_o , expected ), (
906
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } // { right } )={ expected } "
907
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
908
+ )
909
+ else :
910
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
911
+ scalar_l = scalar_type (left [l_idx ])
912
+ scalar_r = scalar_type (right [r_idx ])
913
+ expected = scalar_l // scalar_r
914
+ scalar_o = scalar_type (res [o_idx ])
915
+ if not all (
916
+ math .isfinite (n ) for n in [scalar_l , scalar_r , scalar_o , expected ]
917
+ ):
918
+ continue
919
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
920
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
921
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
922
+ assert isclose (scalar_o , expected ), (
923
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } // { f_r } )={ expected } "
924
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
895
925
)
896
- else :
897
- div = xp .divide (left , right )
898
-
899
- # TODO: The spec doesn't clearly specify the behavior of floor_divide on
900
- # infinities. See https://github.com/data-apis/array-api/issues/199.
901
- finite = ah .isfinite (div )
902
- ah .assert_integral (res [finite ])
903
- # TODO: Test the exact output for floor_divide.
904
926
905
927
906
928
@pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , xps .numeric_dtypes ()))
@@ -912,24 +934,37 @@ def test_greater(ctx, data):
912
934
out = ctx .func (left , right )
913
935
914
936
assert_binary_param_dtype (ctx , left , right , out , xp .bool )
915
- if not ctx .right_is_scalar :
916
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
917
- shape = broadcast_shapes (left .shape , right .shape )
918
- ph .assert_shape (ctx .func_name , out .shape , shape )
919
- _left = xp .broadcast_to (left , shape )
920
- _right = xp .broadcast_to (right , shape )
921
-
937
+ assert_binary_param_shape (ctx , left , right , out )
938
+ if ctx .right_is_scalar :
939
+ scalar_type = dh .get_scalar_type (left .dtype )
940
+ for idx in sh .ndindex (left .shape ):
941
+ scalar_l = scalar_type (left [idx ])
942
+ expected = scalar_l > right
943
+ scalar_o = bool (out [idx ])
944
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
945
+ f_o = sh .fmt_idx (ctx .res_name , idx )
946
+ assert scalar_o == expected , (
947
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } > { right } )={ expected } "
948
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
949
+ )
950
+ else :
951
+ # See test_equal note
922
952
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
923
- _left = ah .asarray (_left , dtype = promoted_dtype )
924
- _right = ah .asarray (_right , dtype = promoted_dtype )
925
-
953
+ _left = xp .astype (left , promoted_dtype )
954
+ _right = xp .astype (right , promoted_dtype )
926
955
scalar_type = dh .get_scalar_type (promoted_dtype )
927
- for idx in sh .ndindex (shape ):
928
- out_idx = out [idx ]
929
- x1_idx = _left [idx ]
930
- x2_idx = _right [idx ]
931
- assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
932
- assert bool (out_idx ) == (scalar_type (x1_idx ) > scalar_type (x2_idx ))
956
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , out .shape ):
957
+ scalar_l = scalar_type (_left [l_idx ])
958
+ scalar_r = scalar_type (_right [r_idx ])
959
+ expected = scalar_l > scalar_r
960
+ scalar_o = bool (out [o_idx ])
961
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
962
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
963
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
964
+ assert scalar_o == expected , (
965
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } > { f_r } )={ expected } "
966
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
967
+ )
933
968
934
969
935
970
@pytest .mark .parametrize (
@@ -943,25 +978,37 @@ def test_greater_equal(ctx, data):
943
978
out = ctx .func (left , right )
944
979
945
980
assert_binary_param_dtype (ctx , left , right , out , xp .bool )
946
- if not ctx .right_is_scalar :
947
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
948
-
949
- shape = broadcast_shapes (left .shape , right .shape )
950
- ph .assert_shape (ctx .func_name , out .shape , shape )
951
- _left = xp .broadcast_to (left , shape )
952
- _right = xp .broadcast_to (right , shape )
953
-
981
+ assert_binary_param_shape (ctx , left , right , out )
982
+ if ctx .right_is_scalar :
983
+ scalar_type = dh .get_scalar_type (left .dtype )
984
+ for idx in sh .ndindex (left .shape ):
985
+ scalar_l = scalar_type (left [idx ])
986
+ expected = scalar_l >= right
987
+ scalar_o = bool (out [idx ])
988
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
989
+ f_o = sh .fmt_idx (ctx .res_name , idx )
990
+ assert scalar_o == expected , (
991
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } >= { right } )={ expected } "
992
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
993
+ )
994
+ else :
995
+ # See test_equal note
954
996
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
955
- _left = ah .asarray (_left , dtype = promoted_dtype )
956
- _right = ah .asarray (_right , dtype = promoted_dtype )
957
-
997
+ _left = xp .astype (left , promoted_dtype )
998
+ _right = xp .astype (right , promoted_dtype )
958
999
scalar_type = dh .get_scalar_type (promoted_dtype )
959
- for idx in sh .ndindex (shape ):
960
- out_idx = out [idx ]
961
- x1_idx = _left [idx ]
962
- x2_idx = _right [idx ]
963
- assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
964
- assert bool (out_idx ) == (scalar_type (x1_idx ) >= scalar_type (x2_idx ))
1000
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , out .shape ):
1001
+ scalar_l = scalar_type (_left [l_idx ])
1002
+ scalar_r = scalar_type (_right [r_idx ])
1003
+ expected = scalar_l >= scalar_r
1004
+ scalar_o = bool (out [o_idx ])
1005
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
1006
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
1007
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
1008
+ assert scalar_o == expected , (
1009
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } >= { f_r } )={ expected } "
1010
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1011
+ )
965
1012
966
1013
967
1014
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -1029,25 +1076,37 @@ def test_less(ctx, data):
1029
1076
out = ctx .func (left , right )
1030
1077
1031
1078
assert_binary_param_dtype (ctx , left , right , out , xp .bool )
1032
- if not ctx .right_is_scalar :
1033
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
1034
-
1035
- shape = broadcast_shapes (left .shape , right .shape )
1036
- ph .assert_shape (ctx .func_name , out .shape , shape )
1037
- _left = xp .broadcast_to (left , shape )
1038
- _right = xp .broadcast_to (right , shape )
1039
-
1079
+ assert_binary_param_shape (ctx , left , right , out )
1080
+ if ctx .right_is_scalar :
1081
+ scalar_type = dh .get_scalar_type (left .dtype )
1082
+ for idx in sh .ndindex (left .shape ):
1083
+ scalar_l = scalar_type (left [idx ])
1084
+ expected = scalar_l < right
1085
+ scalar_o = bool (out [idx ])
1086
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
1087
+ f_o = sh .fmt_idx (ctx .res_name , idx )
1088
+ assert scalar_o == expected , (
1089
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } < { right } )={ expected } "
1090
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
1091
+ )
1092
+ else :
1093
+ # See test_equal note
1040
1094
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
1041
- _left = ah .asarray (_left , dtype = promoted_dtype )
1042
- _right = ah .asarray (_right , dtype = promoted_dtype )
1043
-
1095
+ _left = xp .astype (left , promoted_dtype )
1096
+ _right = xp .astype (right , promoted_dtype )
1044
1097
scalar_type = dh .get_scalar_type (promoted_dtype )
1045
- for idx in sh .ndindex (shape ):
1046
- x1_idx = _left [idx ]
1047
- x2_idx = _right [idx ]
1048
- out_idx = out [idx ]
1049
- assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1050
- assert bool (out_idx ) == (scalar_type (x1_idx ) < scalar_type (x2_idx ))
1098
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , out .shape ):
1099
+ scalar_l = scalar_type (_left [l_idx ])
1100
+ scalar_r = scalar_type (_right [r_idx ])
1101
+ expected = scalar_l < scalar_r
1102
+ scalar_o = bool (out [o_idx ])
1103
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
1104
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
1105
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
1106
+ assert scalar_o == expected , (
1107
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } < { f_r } )={ expected } "
1108
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1109
+ )
1051
1110
1052
1111
1053
1112
@pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , xps .numeric_dtypes ()))
@@ -1059,25 +1118,37 @@ def test_less_equal(ctx, data):
1059
1118
out = ctx .func (left , right )
1060
1119
1061
1120
assert_binary_param_dtype (ctx , left , right , out , xp .bool )
1062
- if not ctx .right_is_scalar :
1063
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
1064
-
1065
- shape = broadcast_shapes (left .shape , right .shape )
1066
- ph .assert_shape (ctx .func_name , out .shape , shape )
1067
- _left = xp .broadcast_to (left , shape )
1068
- _right = xp .broadcast_to (right , shape )
1069
-
1121
+ assert_binary_param_shape (ctx , left , right , out )
1122
+ if ctx .right_is_scalar :
1123
+ scalar_type = dh .get_scalar_type (left .dtype )
1124
+ for idx in sh .ndindex (left .shape ):
1125
+ scalar_l = scalar_type (left [idx ])
1126
+ expected = scalar_l <= right
1127
+ scalar_o = bool (out [idx ])
1128
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
1129
+ f_o = sh .fmt_idx (ctx .res_name , idx )
1130
+ assert scalar_o == expected , (
1131
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } <= { right } )={ expected } "
1132
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
1133
+ )
1134
+ else :
1135
+ # See test_equal note
1070
1136
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
1071
- _left = ah .asarray (_left , dtype = promoted_dtype )
1072
- _right = ah .asarray (_right , dtype = promoted_dtype )
1073
-
1137
+ _left = xp .astype (left , promoted_dtype )
1138
+ _right = xp .astype (right , promoted_dtype )
1074
1139
scalar_type = dh .get_scalar_type (promoted_dtype )
1075
- for idx in sh .ndindex (shape ):
1076
- x1_idx = _left [idx ]
1077
- x2_idx = _right [idx ]
1078
- out_idx = out [idx ]
1079
- assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1080
- assert bool (out_idx ) == (scalar_type (x1_idx ) <= scalar_type (x2_idx ))
1140
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , out .shape ):
1141
+ scalar_l = scalar_type (_left [l_idx ])
1142
+ scalar_r = scalar_type (_right [r_idx ])
1143
+ expected = scalar_l <= scalar_r
1144
+ scalar_o = bool (out [o_idx ])
1145
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
1146
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
1147
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
1148
+ assert scalar_o == expected , (
1149
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } <= { f_r } )={ expected } "
1150
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1151
+ )
1081
1152
1082
1153
1083
1154
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -1204,6 +1275,7 @@ def test_multiply(ctx, data):
1204
1275
res = ctx .func (left , right )
1205
1276
1206
1277
assert_binary_param_dtype (ctx , left , right , res )
1278
+ assert_binary_param_shape (ctx , left , right , res )
1207
1279
if not ctx .right_is_scalar :
1208
1280
# multiply is commutative
1209
1281
expected = ctx .func (right , left )
@@ -1308,6 +1380,7 @@ def test_pow(ctx, data):
1308
1380
reject ()
1309
1381
1310
1382
assert_binary_param_dtype (ctx , left , right , res )
1383
+ assert_binary_param_shape (ctx , left , right , res )
1311
1384
# There isn't much we can test here. The spec doesn't require any behavior
1312
1385
# beyond the special cases, and indeed, there aren't many mathematical
1313
1386
# properties of exponentiation that strictly hold for floating-point
@@ -1333,6 +1406,7 @@ def test_remainder(ctx, data):
1333
1406
res = ctx .func (left , right )
1334
1407
1335
1408
assert_binary_param_dtype (ctx , left , right , res )
1409
+ assert_binary_param_shape (ctx , left , right , res )
1336
1410
# TODO: test results
1337
1411
1338
1412
@@ -1414,6 +1488,7 @@ def test_subtract(ctx, data):
1414
1488
reject ()
1415
1489
1416
1490
assert_binary_param_dtype (ctx , left , right , res )
1491
+ assert_binary_param_shape (ctx , left , right , res )
1417
1492
# TODO
1418
1493
1419
1494
0 commit comments