Skip to content

Commit 1e228bc

Browse files
committed
Better values testing in test_not_equal
1 parent f524872 commit 1e228bc

File tree

2 files changed

+53
-25
lines changed

2 files changed

+53
-25
lines changed

array_api_tests/shape_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from functools import lru_cache
23
from itertools import product
34
from typing import Iterator, List, Optional, Tuple, Union
45

@@ -27,7 +28,7 @@ def normalise_axis(
2728

2829

2930
def ndindex(shape):
30-
# TODO: remove
31+
"""Yield every index of shape"""
3132
return (indices[0] for indices in iter_indices(shape))
3233

3334

@@ -105,6 +106,7 @@ def fmt_i(i: AtomicIndex) -> str:
105106
return "..."
106107

107108

109+
@lru_cache
108110
def fmt_idx(sym: str, idx: Index) -> str:
109111
if idx == ():
110112
return sym

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def test_add(
312312
reject()
313313

314314
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
315+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
315316
if not right_is_scalar:
316317
# add is commutative
317318
expected = func(right, left)
@@ -773,16 +774,28 @@ def test_equal(
773774
func_name, left, right, right_is_scalar, out, res_name, xp.bool
774775
)
775776
assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name)
776-
if not right_is_scalar:
777+
if right_is_scalar:
778+
scalar_type = dh.get_scalar_type(left.dtype)
779+
for idx in sh.ndindex(left.shape):
780+
scalar_l = scalar_type(left[idx])
781+
expected = scalar_l == right
782+
scalar_o = bool(out[idx])
783+
f_l = sh.fmt_idx(left_sym, idx)
784+
f_o = sh.fmt_idx(res_name, idx)
785+
assert scalar_o == expected, (
786+
f"{f_o}={scalar_o}, but should be ({f_l} == {right})={expected} "
787+
f"[{func_name}()]\n{f_l}={scalar_l}"
788+
)
789+
else:
777790
# We manually promote the dtypes as incorrect internal type promotion
778-
# could lead to erroneous behaviour that we don't catch. For example
791+
# could lead to false positives. For example
779792
#
780793
# >>> xp.equal(
781794
# ... xp.asarray(1.0, dtype=xp.float32),
782795
# ... xp.asarray(1.00000001, dtype=xp.float64),
783796
# ... )
784797
#
785-
# would incorrectly be True if float64 downcasts to float32 internally.
798+
# would erroneously be True if float64 downcasted to float32.
786799
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
787800
_left = xp.astype(left, promoted_dtype)
788801
_right = xp.astype(right, promoted_dtype)
@@ -792,11 +805,12 @@ def test_equal(
792805
scalar_r = scalar_type(_right[r_idx])
793806
expected = scalar_l == scalar_r
794807
scalar_o = bool(out[o_idx])
808+
f_l = sh.fmt_idx(left_sym, l_idx)
809+
f_r = sh.fmt_idx(right_sym, r_idx)
810+
f_o = sh.fmt_idx(res_name, o_idx)
795811
assert scalar_o == expected, (
796-
f"out[{o_idx}]={scalar_o}, but should be "
797-
f"{left_sym}[{l_idx}]=={right_sym}[{r_idx}]={expected} "
798-
f"({left_sym}[{l_idx}]={scalar_l}, {right_sym}[{r_idx}]={scalar_r}) "
799-
f"[{func_name}()]"
812+
f"{f_o}={scalar_o}, but should be ({f_l} == {f_r})={expected} "
813+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
800814
)
801815

802816

@@ -1311,25 +1325,37 @@ def test_not_equal(
13111325
assert_binary_param_dtype(
13121326
func_name, left, right, right_is_scalar, out, res_name, xp.bool
13131327
)
1314-
if not right_is_scalar:
1315-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
1316-
1317-
shape = broadcast_shapes(left.shape, right.shape)
1318-
ph.assert_shape(func_name, out.shape, shape)
1319-
_left = xp.broadcast_to(left, shape)
1320-
_right = xp.broadcast_to(right, shape)
1321-
1328+
assert_binary_param_shape(func_name, left, right, right_is_scalar, out, res_name)
1329+
if right_is_scalar:
1330+
scalar_type = dh.get_scalar_type(left.dtype)
1331+
for idx in sh.ndindex(left.shape):
1332+
scalar_l = scalar_type(left[idx])
1333+
expected = scalar_l != right
1334+
scalar_o = bool(out[idx])
1335+
f_l = sh.fmt_idx(left_sym, idx)
1336+
f_o = sh.fmt_idx(res_name, idx)
1337+
assert scalar_o == expected, (
1338+
f"{f_o}={scalar_o}, but should be ({f_l} != {right})={expected} "
1339+
f"[{func_name}()]\n{f_l}={scalar_l}"
1340+
)
1341+
else:
1342+
# See test_equal note
13221343
promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
1323-
_left = ah.asarray(_left, dtype=promoted_dtype)
1324-
_right = ah.asarray(_right, dtype=promoted_dtype)
1325-
1344+
_left = xp.astype(left, promoted_dtype)
1345+
_right = xp.astype(right, promoted_dtype)
13261346
scalar_type = dh.get_scalar_type(promoted_dtype)
1327-
for idx in sh.ndindex(shape):
1328-
out_idx = out[idx]
1329-
x1_idx = _left[idx]
1330-
x2_idx = _right[idx]
1331-
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
1332-
assert bool(out_idx) == (scalar_type(x1_idx) != scalar_type(x2_idx))
1347+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, out.shape):
1348+
scalar_l = scalar_type(_left[l_idx])
1349+
scalar_r = scalar_type(_right[r_idx])
1350+
expected = scalar_l != scalar_r
1351+
scalar_o = bool(out[o_idx])
1352+
f_l = sh.fmt_idx(left_sym, l_idx)
1353+
f_r = sh.fmt_idx(right_sym, r_idx)
1354+
f_o = sh.fmt_idx(res_name, o_idx)
1355+
assert scalar_o == expected, (
1356+
f"{f_o}={scalar_o}, but should be ({f_l} != {f_r})={expected} "
1357+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1358+
)
13331359

13341360

13351361
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)