Skip to content

Commit e72184e

Browse files
committed
Refactor remaining elwise/op tests
1 parent 9d1f4da commit e72184e

File tree

1 file changed

+36
-53
lines changed

1 file changed

+36
-53
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+36-53
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
6262
#
6363
# By default, floating-point functions/methods are loosely asserted against. Use
6464
# `strict_check=True` when they should be strictly asserted against, i.e.
65-
# when a function should return intergrals.
65+
# when a function should return intergrals. Likewise, use `strict_check=False`
66+
# when integer function/methods should be loosely asserted against.
6667

6768

6869
def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
@@ -92,7 +93,7 @@ def unary_assert_against_refimpl(
9293
expr_template: Optional[str] = None,
9394
res_stype: Optional[ScalarType] = None,
9495
filter_: Callable[[Scalar], bool] = default_filter,
95-
strict_check: bool = False,
96+
strict_check: Optional[bool] = None,
9697
):
9798
if in_.shape != res.shape:
9899
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
@@ -108,7 +109,7 @@ def unary_assert_against_refimpl(
108109
continue
109110
try:
110111
expected = refimpl(scalar_i)
111-
except OverflowError:
112+
except Exception:
112113
continue
113114
if res.dtype != xp.bool:
114115
assert m is not None and M is not None # for mypy
@@ -118,7 +119,7 @@ def unary_assert_against_refimpl(
118119
f_i = sh.fmt_idx("x", idx)
119120
f_o = sh.fmt_idx("out", idx)
120121
expr = expr_template.format(f_i, expected)
121-
if not strict_check and dh.is_float_dtype(res.dtype):
122+
if strict_check == False or dh.is_float_dtype(res.dtype):
122123
assert isclose(scalar_o, expected), (
123124
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
124125
f"{f_i}={scalar_i}"
@@ -142,7 +143,7 @@ def binary_assert_against_refimpl(
142143
right_sym: str = "x2",
143144
res_name: str = "out",
144145
filter_: Callable[[Scalar], bool] = default_filter,
145-
strict_check: bool = False,
146+
strict_check: Optional[bool] = None,
146147
):
147148
if expr_template is None:
148149
expr_template = func_name + "({}, {})={}"
@@ -157,7 +158,7 @@ def binary_assert_against_refimpl(
157158
continue
158159
try:
159160
expected = refimpl(scalar_l, scalar_r)
160-
except OverflowError:
161+
except Exception:
161162
continue
162163
if res.dtype != xp.bool:
163164
assert m is not None and M is not None # for mypy
@@ -168,7 +169,7 @@ def binary_assert_against_refimpl(
168169
f_r = sh.fmt_idx(right_sym, r_idx)
169170
f_o = sh.fmt_idx(res_name, o_idx)
170171
expr = expr_template.format(f_l, f_r, expected)
171-
if not strict_check and dh.is_float_dtype(res.dtype):
172+
if strict_check == False or dh.is_float_dtype(res.dtype):
172173
assert isclose(scalar_o, expected), (
173174
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
174175
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
@@ -384,11 +385,12 @@ def binary_param_assert_against_refimpl(
384385
refimpl: Callable[[Scalar, Scalar], Scalar],
385386
res_stype: Optional[ScalarType] = None,
386387
filter_: Callable[[Scalar], bool] = default_filter,
387-
strict_check: bool = False,
388+
strict_check: Optional[bool] = None,
388389
):
389390
expr_template = "({} " + op_sym + " {})={}"
390391
if ctx.right_is_scalar:
391-
assert filter_(right) # sanity check
392+
if filter_(right):
393+
return # short-circuit here as there will be nothing to test
392394
in_stype = dh.get_scalar_type(left.dtype)
393395
if res_stype is None:
394396
res_stype = in_stype
@@ -399,7 +401,7 @@ def binary_param_assert_against_refimpl(
399401
continue
400402
try:
401403
expected = refimpl(scalar_l, right)
402-
except OverflowError:
404+
except Exception:
403405
continue
404406
if left.dtype != xp.bool:
405407
assert m is not None and M is not None # for mypy
@@ -409,7 +411,7 @@ def binary_param_assert_against_refimpl(
409411
f_l = sh.fmt_idx(ctx.left_sym, idx)
410412
f_o = sh.fmt_idx(ctx.res_name, idx)
411413
expr = expr_template.format(f_l, right, expected)
412-
if not strict_check and dh.is_float_dtype(left.dtype):
414+
if strict_check == False or dh.is_float_dtype(res.dtype):
413415
assert isclose(scalar_o, expected), (
414416
f"{f_o}={scalar_o}, but should be roughly {expr} "
415417
f"[{ctx.func_name}()]\n"
@@ -704,16 +706,22 @@ def test_cosh(x):
704706
def test_divide(ctx, data):
705707
left = data.draw(ctx.left_strat, label=ctx.left_sym)
706708
right = data.draw(ctx.right_strat, label=ctx.right_sym)
709+
if ctx.right_is_scalar:
710+
assume
707711

708712
res = ctx.func(left, right)
709713

710714
binary_param_assert_dtype(ctx, left, right, res)
711715
binary_param_assert_shape(ctx, left, right, res)
712-
# There isn't much we can test here. The spec doesn't require any behavior
713-
# beyond the special cases, and indeed, there aren't many mathematical
714-
# properties of division that strictly hold for floating-point numbers. We
715-
# could test that this does implement IEEE 754 division, but we don't yet
716-
# have those sorts in general for this module.
716+
binary_param_assert_against_refimpl(
717+
ctx,
718+
left,
719+
right,
720+
res,
721+
"/",
722+
operator.truediv,
723+
filter_=lambda s: math.isfinite(s) and s != 0,
724+
)
717725

718726

719727
@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes()))
@@ -836,17 +844,7 @@ def test_isfinite(x):
836844
out = ah.isfinite(x)
837845
ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool)
838846
ph.assert_shape("isfinite", out.shape, x.shape)
839-
if dh.is_int_dtype(x.dtype):
840-
ah.assert_exactly_equal(out, ah.true(x.shape))
841-
# Test that isfinite, isinf, and isnan are self-consistent.
842-
inf = ah.logical_or(xp.isinf(x), ah.isnan(x))
843-
ah.assert_exactly_equal(out, ah.logical_not(inf))
844-
845-
# Test the exact value by comparing to the math version
846-
if dh.is_float_dtype(x.dtype):
847-
for idx in sh.ndindex(x.shape):
848-
s = float(x[idx])
849-
assert bool(out[idx]) == math.isfinite(s)
847+
unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool)
850848

851849

852850
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -949,9 +947,10 @@ def test_log10(x):
949947
def test_logaddexp(x1, x2):
950948
out = xp.logaddexp(x1, x2)
951949
ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype)
952-
# The spec doesn't require any behavior for this function. We could test
953-
# that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
954-
# don't have tests for this sort of thing for any functions yet.
950+
ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape)
951+
binary_assert_against_refimpl(
952+
"logaddexp", x1, x2, out, lambda l, r: math.log(math.exp(l) + math.exp(r))
953+
)
955954

956955

957956
@given(*hh.two_mutual_arrays([xp.bool]))
@@ -1078,11 +1077,9 @@ def test_pow(ctx, data):
10781077

10791078
binary_param_assert_dtype(ctx, left, right, res)
10801079
binary_param_assert_shape(ctx, left, right, res)
1081-
# There isn't much we can test here. The spec doesn't require any behavior
1082-
# beyond the special cases, and indeed, there aren't many mathematical
1083-
# properties of exponentiation that strictly hold for floating-point
1084-
# numbers. We could test that this does implement IEEE 754 pow, but we
1085-
# don't yet have those sorts in general for this module.
1080+
binary_param_assert_against_refimpl(
1081+
ctx, left, right, res, "**", math.pow, strict_check=False
1082+
)
10861083

10871084

10881085
@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes()))
@@ -1110,28 +1107,14 @@ def test_round(x):
11101107
unary_assert_against_refimpl("round", x, out, round, strict_check=True)
11111108

11121109

1113-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
1110+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
11141111
def test_sign(x):
11151112
out = xp.sign(x)
11161113
ph.assert_dtype("sign", x.dtype, out.dtype)
11171114
ph.assert_shape("sign", out.shape, x.shape)
1118-
scalar_type = dh.get_scalar_type(out.dtype)
1119-
for idx in sh.ndindex(x.shape):
1120-
scalar_x = scalar_type(x[idx])
1121-
f_x = sh.fmt_idx("x", idx)
1122-
if math.isnan(scalar_x):
1123-
continue
1124-
if scalar_x == 0:
1125-
expected = 0
1126-
expr = f"{f_x}=0"
1127-
else:
1128-
expected = 1 if scalar_x > 0 else -1
1129-
expr = f"({f_x} / |{f_x}|)={expected}"
1130-
scalar_o = scalar_type(out[idx])
1131-
f_o = sh.fmt_idx("out", idx)
1132-
assert (
1133-
scalar_o == expected
1134-
), f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
1115+
unary_assert_against_refimpl(
1116+
"sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0
1117+
)
11351118

11361119

11371120
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)