Skip to content

Commit 84fcf98

Browse files
committed
Update test_remainder
1 parent 554cfd9 commit 84fcf98

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
3939
return xps.boolean_dtypes() | all_integer_dtypes()
4040

4141

42-
def isclose(n1: float, n2: float):
42+
def isclose(n1: Union[int, float], n2: Union[int, float]):
4343
if not (math.isfinite(n1) and math.isfinite(n2)):
4444
raise ValueError(f"{n1=} and {n1=}, but input must be finite")
4545
return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1)
@@ -1394,20 +1394,45 @@ def test_remainder(ctx, data):
13941394
left = data.draw(ctx.left_strat, label=ctx.left_sym)
13951395
right = data.draw(ctx.right_strat, label=ctx.right_sym)
13961396
if ctx.right_is_scalar:
1397-
out_dtype = left.dtype
1397+
assume(right != 0)
13981398
else:
1399-
out_dtype = dh.result_type(left.dtype, right.dtype)
1400-
if dh.is_int_dtype(out_dtype):
1401-
if ctx.right_is_scalar:
1402-
assume(right != 0)
1403-
else:
1404-
assume(not ah.any(right == 0))
1399+
assume(not ah.any(right == 0))
14051400

14061401
res = ctx.func(left, right)
14071402

14081403
assert_binary_param_dtype(ctx, left, right, res)
14091404
assert_binary_param_shape(ctx, left, right, res)
1410-
# TODO: test results
1405+
scalar_type = dh.get_scalar_type(res.dtype)
1406+
if ctx.right_is_scalar:
1407+
for idx in sh.ndindex(res.shape):
1408+
scalar_l = scalar_type(left[idx])
1409+
expected = scalar_l % right
1410+
scalar_o = scalar_type(res[idx])
1411+
if not all(math.isfinite(n) for n in [scalar_l, right, scalar_o, expected]):
1412+
continue
1413+
f_l = sh.fmt_idx(ctx.left_sym, idx)
1414+
f_o = sh.fmt_idx(ctx.res_name, idx)
1415+
assert isclose(scalar_o, expected), (
1416+
f"{f_o}={scalar_o}, but should be roughly ({f_l} % {right})={expected} "
1417+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
1418+
)
1419+
else:
1420+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
1421+
scalar_l = scalar_type(left[l_idx])
1422+
scalar_r = scalar_type(right[r_idx])
1423+
expected = scalar_l % scalar_r
1424+
scalar_o = scalar_type(res[o_idx])
1425+
if not all(
1426+
math.isfinite(n) for n in [scalar_l, scalar_r, scalar_o, expected]
1427+
):
1428+
continue
1429+
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
1430+
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
1431+
f_o = sh.fmt_idx(ctx.res_name, o_idx)
1432+
assert isclose(scalar_o, expected), (
1433+
f"{f_o}={scalar_o}, but should be roughly ({f_l} % {f_r})={expected} "
1434+
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1435+
)
14111436

14121437

14131438
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)