Skip to content

Commit 0979dea

Browse files
committed
Values testing for test_sign
1 parent 6a51585 commit 0979dea

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,23 @@ def test_sign(x):
14731473
out = xp.sign(x)
14741474
ph.assert_dtype("sign", x.dtype, out.dtype)
14751475
ph.assert_shape("sign", out.shape, x.shape)
1476-
# TODO
1476+
scalar_type = dh.get_scalar_type(x.dtype)
1477+
for idx in sh.ndindex(x.shape):
1478+
scalar_x = scalar_type(x[idx])
1479+
f_x = sh.fmt_idx("x", idx)
1480+
if math.isnan(scalar_x):
1481+
continue
1482+
if scalar_x == 0:
1483+
expected = 0
1484+
expr = f"{f_x}=0"
1485+
else:
1486+
expected = 1 if scalar_x > 0 else -1
1487+
expr = f"({f_x} / |{f_x}|)={expected}"
1488+
scalar_o = scalar_type(out[idx])
1489+
f_o = sh.fmt_idx("out", idx)
1490+
assert scalar_o == expected, (
1491+
f"{f_o}={scalar_o}, but should be {expr} [sign()]\n{f_x}={scalar_x}"
1492+
)
14771493

14781494

14791495
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)