Skip to content

Commit 06a1944

Browse files
committed
Values testing for test_add
1 parent 0979dea commit 06a1944

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,37 @@ def test_add(ctx, data):
310310

311311
assert_binary_param_dtype(ctx, left, right, res)
312312
assert_binary_param_shape(ctx, left, right, res)
313-
if not ctx.right_is_scalar:
314-
# add is commutative
315-
expected = ctx.func(right, left)
316-
ah.assert_exactly_equal(res, expected)
313+
m, M = dh.dtype_ranges[res.dtype]
314+
scalar_type = dh.get_scalar_type(res.dtype)
315+
if ctx.right_is_scalar:
316+
for idx in sh.ndindex(res.shape):
317+
scalar_l = scalar_type(left[idx])
318+
expected = scalar_l + right
319+
if not math.isfinite(expected) or expected <= m or expected >= M:
320+
continue
321+
scalar_o = scalar_type(res[idx])
322+
f_l = sh.fmt_idx(ctx.left_sym, idx)
323+
f_o = sh.fmt_idx(ctx.res_name, idx)
324+
assert isclose(scalar_o, expected), (
325+
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {right})={expected} "
326+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
327+
)
328+
else:
329+
ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative
330+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
331+
scalar_l = scalar_type(left[l_idx])
332+
scalar_r = scalar_type(right[r_idx])
333+
expected = scalar_l + scalar_r
334+
if not math.isfinite(expected) or expected <= m or expected >= M:
335+
continue
336+
scalar_o = scalar_type(res[o_idx])
337+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
338+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
339+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
340+
assert isclose(scalar_o, expected), (
341+
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {f_r})={expected} "
342+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
343+
)
317344

318345

319346
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -1487,9 +1514,9 @@ def test_sign(x):
14871514
expr = f"({f_x} / |{f_x}|)={expected}"
14881515
scalar_o = scalar_type(out[idx])
14891516
f_o = sh.fmt_idx("out", idx)
1490-
assert scalar_o == expected, (
1491-
f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
1492-
)
1517+
assert (
1518+
scalar_o == expected
1519+
), f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
14931520

14941521

14951522
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)