Skip to content

Commit 554cfd9

Browse files
committed
Apply iter_indices() logic to binary op/elwise tests
1 parent 2ab8a56 commit 554cfd9

File tree

1 file changed

+158
-83
lines changed

1 file changed

+158
-83
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 158 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3939
return xps.boolean_dtypes() | all_integer_dtypes()
4040

4141

42+
def isclose(n1: float, n2: float):
43+
if not (math.isfinite(n1) and math.isfinite(n2)):
44+
raise ValueError(f"{n1=} and {n1=}, but input must be finite")
45+
return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1)
46+
47+
4248
# When appropiate, this module tests operators alongside their respective
4349
# elementwise methods. We do this by parametrizing a generalised test method
4450
# with every relevant method and operator.
@@ -766,6 +772,7 @@ def test_divide(ctx, data):
766772
res = ctx.func(left, right)
767773

768774
assert_binary_param_dtype(ctx, left, right, res)
775+
assert_binary_param_shape(ctx, left, right, res)
769776
# There isn't much we can test here. The spec doesn't require any behavior
770777
# beyond the special cases, and indeed, there aren't many mathematical
771778
# properties of division that strictly hold for floating-point numbers. We
@@ -884,23 +891,38 @@ def test_floor_divide(ctx, data):
884891
res = ctx.func(left, right)
885892

886893
assert_binary_param_dtype(ctx, left, right, res)
887-
if not ctx.right_is_scalar:
888-
if dh.is_int_dtype(left.dtype):
889-
# The spec does not specify the behavior for division by 0 for integer
890-
# dtypes. A library may choose to raise an exception in this case, so
891-
# we avoid passing it in entirely.
892-
div = xp.divide(
893-
ah.asarray(left, dtype=xp.float64),
894-
ah.asarray(right, dtype=xp.float64),
894+
assert_binary_param_shape(ctx, left, right, res)
895+
scalar_type = dh.get_scalar_type(res.dtype)
896+
if ctx.right_is_scalar:
897+
for idx in sh.ndindex(res.shape):
898+
scalar_l = scalar_type(left[idx])
899+
expected = scalar_l // right
900+
scalar_o = scalar_type(res[idx])
901+
if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]):
902+
continue
903+
f_l = sh.fmt_idx(ctx.left_sym, idx)
904+
f_o = sh.fmt_idx(ctx.res_name, idx)
905+
assert isclose(scalar_o, expected), (
906+
f"{f_o}={scalar_o}, but should be roughly ({f_l} // {right})={expected} "
907+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
908+
)
909+
else:
910+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
911+
scalar_l = scalar_type(left[l_idx])
912+
scalar_r = scalar_type(right[r_idx])
913+
expected = scalar_l // scalar_r
914+
scalar_o = scalar_type(res[o_idx])
915+
if not all(
916+
math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected]
917+
):
918+
continue
919+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
920+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
921+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
922+
assert isclose(scalar_o, expected), (
923+
f"{f_o}={scalar_o}, but should be roughly ({f_l} // {f_r})={expected} "
924+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
895925
)
896-
else:
897-
div = xp.divide(left, right)
898-
899-
# TODO: The spec doesn't clearly specify the behavior of floor_divide on
900-
# infinities. See https://github.com/data-apis/array-api/issues/199.
901-
finite = ah.isfinite(div)
902-
ah.assert_integral(res[finite])
903-
# TODO: Test the exact output for floor_divide.
904926

905927

906928
@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes()))
@@ -912,24 +934,37 @@ def test_greater(ctx, data):
912934
out = ctx.func(left, right)
913935

914936
assert_binary_param_dtype(ctx, left, right, out, xp.bool)
915-
if not ctx.right_is_scalar:
916-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
917-
shape = broadcast_shapes(left.shape, right.shape)
918-
ph.assert_shape(ctx.func_name, out.shape, shape)
919-
_left = xp.broadcast_to(left, shape)
920-
_right = xp.broadcast_to(right, shape)
921-
937+
assert_binary_param_shape(ctx, left, right, out)
938+
if ctx.right_is_scalar:
939+
scalar_type = dh.get_scalar_type(left.dtype)
940+
for idx in sh.ndindex(left.shape):
941+
scalar_l = scalar_type(left[idx])
942+
expected = scalar_l > right
943+
scalar_o = bool(out[idx])
944+
f_l = sh.fmt_idx(ctx.left_sym, idx)
945+
f_o = sh.fmt_idx(ctx.res_name, idx)
946+
assert scalar_o == expected, (
947+
f"{f_o}={scalar_o}, but should be ({f_l} > {right})={expected} "
948+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
949+
)
950+
else:
951+
# See test_equal note
922952
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
923-
_left = ah.asarray(_left, dtype=promoted_dtype)
924-
_right = ah.asarray(_right, dtype=promoted_dtype)
925-
953+
_left = xp.astype(left, promoted_dtype)
954+
_right = xp.astype(right, promoted_dtype)
926955
scalar_type = dh.get_scalar_type(promoted_dtype)
927-
for idx in sh.ndindex(shape):
928-
out_idx = out[idx]
929-
x1_idx = _left[idx]
930-
x2_idx = _right[idx]
931-
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
932-
assert bool(out_idx) == (scalar_type(x1_idx) > scalar_type(x2_idx))
956+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape):
957+
scalar_l = scalar_type(_left[l_idx])
958+
scalar_r = scalar_type(_right[r_idx])
959+
expected = scalar_l > scalar_r
960+
scalar_o = bool(out[o_idx])
961+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
962+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
963+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
964+
assert scalar_o == expected, (
965+
f"{f_o}={scalar_o}, but should be ({f_l} > {f_r})={expected} "
966+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
967+
)
933968

934969

935970
@pytest.mark.parametrize(
@@ -943,25 +978,37 @@ def test_greater_equal(ctx, data):
943978
out = ctx.func(left, right)
944979

945980
assert_binary_param_dtype(ctx, left, right, out, xp.bool)
946-
if not ctx.right_is_scalar:
947-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
948-
949-
shape = broadcast_shapes(left.shape, right.shape)
950-
ph.assert_shape(ctx.func_name, out.shape, shape)
951-
_left = xp.broadcast_to(left, shape)
952-
_right = xp.broadcast_to(right, shape)
953-
981+
assert_binary_param_shape(ctx, left, right, out)
982+
if ctx.right_is_scalar:
983+
scalar_type = dh.get_scalar_type(left.dtype)
984+
for idx in sh.ndindex(left.shape):
985+
scalar_l = scalar_type(left[idx])
986+
expected = scalar_l >= right
987+
scalar_o = bool(out[idx])
988+
f_l = sh.fmt_idx(ctx.left_sym, idx)
989+
f_o = sh.fmt_idx(ctx.res_name, idx)
990+
assert scalar_o == expected, (
991+
f"{f_o}={scalar_o}, but should be ({f_l} >= {right})={expected} "
992+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
993+
)
994+
else:
995+
# See test_equal note
954996
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
955-
_left = ah.asarray(_left, dtype=promoted_dtype)
956-
_right = ah.asarray(_right, dtype=promoted_dtype)
957-
997+
_left = xp.astype(left, promoted_dtype)
998+
_right = xp.astype(right, promoted_dtype)
958999
scalar_type = dh.get_scalar_type(promoted_dtype)
959-
for idx in sh.ndindex(shape):
960-
out_idx = out[idx]
961-
x1_idx = _left[idx]
962-
x2_idx = _right[idx]
963-
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
964-
assert bool(out_idx) == (scalar_type(x1_idx) >= scalar_type(x2_idx))
1000+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape):
1001+
scalar_l = scalar_type(_left[l_idx])
1002+
scalar_r = scalar_type(_right[r_idx])
1003+
expected = scalar_l >= scalar_r
1004+
scalar_o = bool(out[o_idx])
1005+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
1006+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
1007+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
1008+
assert scalar_o == expected, (
1009+
f"{f_o}={scalar_o}, but should be ({f_l} >= {f_r})={expected} "
1010+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1011+
)
9651012

9661013

9671014
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -1029,25 +1076,37 @@ def test_less(ctx, data):
10291076
out = ctx.func(left, right)
10301077

10311078
assert_binary_param_dtype(ctx, left, right, out, xp.bool)
1032-
if not ctx.right_is_scalar:
1033-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
1034-
1035-
shape = broadcast_shapes(left.shape, right.shape)
1036-
ph.assert_shape(ctx.func_name, out.shape, shape)
1037-
_left = xp.broadcast_to(left, shape)
1038-
_right = xp.broadcast_to(right, shape)
1039-
1079+
assert_binary_param_shape(ctx, left, right, out)
1080+
if ctx.right_is_scalar:
1081+
scalar_type = dh.get_scalar_type(left.dtype)
1082+
for idx in sh.ndindex(left.shape):
1083+
scalar_l = scalar_type(left[idx])
1084+
expected = scalar_l < right
1085+
scalar_o = bool(out[idx])
1086+
f_l = sh.fmt_idx(ctx.left_sym, idx)
1087+
f_o = sh.fmt_idx(ctx.res_name, idx)
1088+
assert scalar_o == expected, (
1089+
f"{f_o}={scalar_o}, but should be ({f_l} < {right})={expected} "
1090+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
1091+
)
1092+
else:
1093+
# See test_equal note
10401094
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
1041-
_left = ah.asarray(_left, dtype=promoted_dtype)
1042-
_right = ah.asarray(_right, dtype=promoted_dtype)
1043-
1095+
_left = xp.astype(left, promoted_dtype)
1096+
_right = xp.astype(right, promoted_dtype)
10441097
scalar_type = dh.get_scalar_type(promoted_dtype)
1045-
for idx in sh.ndindex(shape):
1046-
x1_idx = _left[idx]
1047-
x2_idx = _right[idx]
1048-
out_idx = out[idx]
1049-
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
1050-
assert bool(out_idx) == (scalar_type(x1_idx) < scalar_type(x2_idx))
1098+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape):
1099+
scalar_l = scalar_type(_left[l_idx])
1100+
scalar_r = scalar_type(_right[r_idx])
1101+
expected = scalar_l < scalar_r
1102+
scalar_o = bool(out[o_idx])
1103+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
1104+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
1105+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
1106+
assert scalar_o == expected, (
1107+
f"{f_o}={scalar_o}, but should be ({f_l} < {f_r})={expected} "
1108+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1109+
)
10511110

10521111

10531112
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes()))
@@ -1059,25 +1118,37 @@ def test_less_equal(ctx, data):
10591118
out = ctx.func(left, right)
10601119

10611120
assert_binary_param_dtype(ctx, left, right, out, xp.bool)
1062-
if not ctx.right_is_scalar:
1063-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
1064-
1065-
shape = broadcast_shapes(left.shape, right.shape)
1066-
ph.assert_shape(ctx.func_name, out.shape, shape)
1067-
_left = xp.broadcast_to(left, shape)
1068-
_right = xp.broadcast_to(right, shape)
1069-
1121+
assert_binary_param_shape(ctx, left, right, out)
1122+
if ctx.right_is_scalar:
1123+
scalar_type = dh.get_scalar_type(left.dtype)
1124+
for idx in sh.ndindex(left.shape):
1125+
scalar_l = scalar_type(left[idx])
1126+
expected = scalar_l <= right
1127+
scalar_o = bool(out[idx])
1128+
f_l = sh.fmt_idx(ctx.left_sym, idx)
1129+
f_o = sh.fmt_idx(ctx.res_name, idx)
1130+
assert scalar_o == expected, (
1131+
f"{f_o}={scalar_o}, but should be ({f_l} <= {right})={expected} "
1132+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
1133+
)
1134+
else:
1135+
# See test_equal note
10701136
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
1071-
_left = ah.asarray(_left, dtype=promoted_dtype)
1072-
_right = ah.asarray(_right, dtype=promoted_dtype)
1073-
1137+
_left = xp.astype(left, promoted_dtype)
1138+
_right = xp.astype(right, promoted_dtype)
10741139
scalar_type = dh.get_scalar_type(promoted_dtype)
1075-
for idx in sh.ndindex(shape):
1076-
x1_idx = _left[idx]
1077-
x2_idx = _right[idx]
1078-
out_idx = out[idx]
1079-
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
1080-
assert bool(out_idx) == (scalar_type(x1_idx) <= scalar_type(x2_idx))
1140+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape):
1141+
scalar_l = scalar_type(_left[l_idx])
1142+
scalar_r = scalar_type(_right[r_idx])
1143+
expected = scalar_l <= scalar_r
1144+
scalar_o = bool(out[o_idx])
1145+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
1146+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
1147+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
1148+
assert scalar_o == expected, (
1149+
f"{f_o}={scalar_o}, but should be ({f_l} <= {f_r})={expected} "
1150+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1151+
)
10811152

10821153

10831154
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -1204,6 +1275,7 @@ def test_multiply(ctx, data):
12041275
res = ctx.func(left, right)
12051276

12061277
assert_binary_param_dtype(ctx, left, right, res)
1278+
assert_binary_param_shape(ctx, left, right, res)
12071279
if not ctx.right_is_scalar:
12081280
# multiply is commutative
12091281
expected = ctx.func(right, left)
@@ -1308,6 +1380,7 @@ def test_pow(ctx, data):
13081380
reject()
13091381

13101382
assert_binary_param_dtype(ctx, left, right, res)
1383+
assert_binary_param_shape(ctx, left, right, res)
13111384
# There isn't much we can test here. The spec doesn't require any behavior
13121385
# beyond the special cases, and indeed, there aren't many mathematical
13131386
# properties of exponentiation that strictly hold for floating-point
@@ -1333,6 +1406,7 @@ def test_remainder(ctx, data):
13331406
res = ctx.func(left, right)
13341407

13351408
assert_binary_param_dtype(ctx, left, right, res)
1409+
assert_binary_param_shape(ctx, left, right, res)
13361410
# TODO: test results
13371411

13381412

@@ -1414,6 +1488,7 @@ def test_subtract(ctx, data):
14141488
reject()
14151489

14161490
assert_binary_param_dtype(ctx, left, right, res)
1491+
assert_binary_param_shape(ctx, left, right, res)
14171492
# TODO
14181493

14191494

0 commit comments

Comments
 (0)