Skip to content

Commit 4420817

Browse files
committed
Remove mask usage in ph.assert_array()
1 parent f943eb8 commit 4420817

File tree

3 files changed

+35
-12
lines changed

3 files changed

+35
-12
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pytest import raises
22

3-
from .. import pytest_helpers as ph
43
from .. import _array_module as xp
4+
from .. import pytest_helpers as ph
55

66

77
def test_assert_dtype():
@@ -11,3 +11,12 @@ def test_assert_dtype():
1111
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
1212
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
1313
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
14+
15+
16+
def test_assert_array():
17+
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
18+
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
19+
with raises(AssertionError):
20+
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
21+
with raises(AssertionError):
22+
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"assert_shape",
2525
"assert_result_shape",
2626
"assert_keepdimable_shape",
27+
"assert_0d_equals",
2728
"assert_fill",
2829
"assert_array",
2930
]
@@ -242,15 +243,28 @@ def assert_fill(
242243
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
243244
assert_dtype(func_name, out.dtype, expected.dtype)
244245
assert_shape(func_name, out.shape, expected.shape, **kw)
245-
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
246+
f_func = f"[{func_name}({fmt_kw(kw)})]"
246247
if dh.is_float_dtype(out.dtype):
247-
neg_zeros = expected == -0.0
248-
assert xp.all((out == -0.0) == neg_zeros), msg
249-
pos_zeros = expected == +0.0
250-
assert xp.all((out == +0.0) == pos_zeros), msg
251-
nans = xp.isnan(expected)
252-
assert xp.all(xp.isnan(out) == nans), msg
253-
mask = ~(neg_zeros | pos_zeros | nans)
254-
assert xp.all(out[mask] == expected[mask]), msg
248+
for idx in sh.ndindex(out.shape):
249+
at_out = out[idx]
250+
at_expected = expected[idx]
251+
msg = (
252+
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
253+
f"{f_func}"
254+
)
255+
if xp.isnan(at_expected):
256+
assert xp.isnan(at_out), msg
257+
elif at_expected == 0.0 or at_expected == -0.0:
258+
scalar_at_expected = float(at_expected)
259+
scalar_at_out = float(at_out)
260+
if is_pos_zero(scalar_at_expected):
261+
assert is_pos_zero(scalar_at_out), msg
262+
else:
263+
assert is_neg_zero(scalar_at_expected) # sanity check
264+
assert is_neg_zero(scalar_at_out), msg
265+
else:
266+
assert at_out == at_expected, msg
255267
else:
256-
assert xp.all(out == expected), msg
268+
assert xp.all(out == expected), (
269+
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
270+
)

array_api_tests/test_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_asarray_arrays(x, data):
280280
if copy:
281281
assert not xp.all(
282282
out == x
283-
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
283+
), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
284284
elif copy is False:
285285
pass # TODO
286286

0 commit comments

Comments
 (0)