Skip to content

Commit 8e375b7

Browse files
committed
Add test_equal()
I can't be completely sure that this test is correct until we support testing arbitrary arrays in the elementwise tests (#17), although it should be. This also corresponds to the true type promotion test for equal(), since the test in test_type_promotion.py cannot actually test that equal() does the correct type promotion internally.
1 parent 64204a2 commit 8e375b7

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

array_api_tests/test_elementwise_functions.py

+53-7
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
infinity, isnegative, all as array_all, any as
3333
array_any, int_to_dtype, bool as bool_dtype,
3434
assert_integral, less_equal, isintegral,
35-
isfinite)
35+
isfinite, ndindex, promote_dtypes,
36+
is_integer_dtype, is_float_dtype)
37+
# We might as well use this implementation rather than requiring
38+
# mod.broadcast_shapes(). See test_equal() and others.
39+
from .test_broadcasting import broadcast_shapes
40+
41+
from .test_type_promotion import promotion_table, dtype_mapping
3642

3743
from . import _array_module
3844

@@ -56,11 +62,9 @@ def two_array_scalars(draw, dtype1, dtype2):
5662
return draw(array_scalars(just(dtype1))), draw(array_scalars(just(dtype2)))
5763

5864
def sanity_check(x1, x2):
59-
from .test_type_promotion import promotion_table, dtype_mapping
60-
t1 = [i for i in dtype_mapping if dtype_mapping[i] == x1.dtype][0]
61-
t2 = [i for i in dtype_mapping if dtype_mapping[i] == x2.dtype][0]
62-
63-
if (t1, t2) not in promotion_table:
65+
try:
66+
promote_dtypes(x1.dtype, x2.dtype)
67+
except ValueError:
6468
raise RuntimeError("Error in test generation (probably a bug in the test suite")
6569

6670
@given(numeric_scalars)
@@ -353,7 +357,49 @@ def test_divide(args):
353357
def test_equal(args):
354358
x1, x2 = args
355359
sanity_check(x1, x2)
356-
# a = _array_module.equal(x1, x2)
360+
a = _array_module.equal(x1, x2)
361+
# NOTE: assert_exactly_equal() itself uses equal(), so we must be careful
362+
# not to use it here. Otherwise, the test would be circular and
363+
# meaningless. Instead, we implement this by iterating every element of
364+
# the arrays and comparing them. The logic here is also used for the tests
365+
# for the other elementwise functions that accept any input dtype but
366+
# always return bool (greater(), greater_equal(), less(), less_equal(),
367+
# and not_equal()).
368+
369+
# First we broadcast the arrays so that they can be indexed uniformly.
370+
# TODO: it should be possible to skip this step if we instead generate
371+
# indices to x1 and x2 that correspond to the broadcasted shapes. This
372+
# would avoid the dependence in this test on broadcast_to().
373+
shape = broadcast_shapes(x1.shape, x2.shape)
374+
_x1 = _array_module.broadcast_to(x1, shape)
375+
_x2 = _array_module.broadcast_to(x2, shape)
376+
377+
# Second, manually promote the dtypes. This is important. If the internal
378+
# type promotion in equal() is wrong, it will not be directly visible in
379+
# the output type, but it can lead to wrong answers. For example,
380+
# equal(array(1.0, dtype=float32), array(1.00000001, dtype=float64)) will
381+
# be wrong if the float64 is downcast to float32. See the comment on
382+
# test_elementwise_function_two_arg_bool_type_promotion() in
383+
# test_type_promotion.py. The type promotion for equal() is not tested in
384+
# that file because doing so requires doing the consistency check we do
385+
# here.
386+
promoted_dtype = promote_dtypes(x1.dtype, x2.dtype)
387+
_x1 = _array_module.asarray(_x1, dtype=promoted_dtype)
388+
_x2 = _array_module.asarray(_x2, dtype=promoted_dtype)
389+
390+
if is_integer_dtype(promoted_dtype):
391+
scalar_func = int
392+
elif is_float_dtype(promoted_dtype):
393+
scalar_func = float
394+
else:
395+
scalar_func = bool
396+
for idx in ndindex(shape):
397+
# Sanity check
398+
aidx = a[idx]
399+
x1idx = _x1[idx]
400+
x2idx = _x2[idx]
401+
assert aidx.shape == x1idx.shape == x2idx.shape
402+
assert bool(aidx) == (scalar_func(x1idx) == scalar_func(x2idx))
357403

358404
@given(floating_scalars)
359405
def test_exp(x):

array_api_tests/test_type_promotion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ def dtype_signed(dtype):
363363
elementwise_function_two_arg_bool_parametrize_ids = ['-'.join((n, d1, d2)) for n, (d1, d2)
364364
in elementwise_function_two_arg_bool_parametrize_inputs]
365365

366-
# TODO: These functions should still do type promotion internally, but
367-
# we do not test this here (it will be tested in the corresponding tests in
366+
# TODO: These functions should still do type promotion internally, but we do
367+
# not test this here (it is tested in the corresponding tests in
368368
# test_elementwise_functions.py). This can affect the resulting values if not
369369
# done correctly. For example, greater_equal(array(1.0, dtype=float32),
370370
# array(1.00000001, dtype=float64)) will be wrong if the float64 array is

0 commit comments

Comments
 (0)