32
32
infinity , isnegative , all as array_all , any as
33
33
array_any , int_to_dtype , bool as bool_dtype ,
34
34
assert_integral , less_equal , isintegral ,
35
- isfinite )
35
+ isfinite , ndindex , promote_dtypes ,
36
+ is_integer_dtype , is_float_dtype )
37
+ # We might as well use this implementation rather than requiring
38
+ # mod.broadcast_shapes(). See test_equal() and others.
39
+ from .test_broadcasting import broadcast_shapes
40
+
41
+ from .test_type_promotion import promotion_table , dtype_mapping
36
42
37
43
from . import _array_module
38
44
@@ -56,11 +62,9 @@ def two_array_scalars(draw, dtype1, dtype2):
56
62
return draw (array_scalars (just (dtype1 ))), draw (array_scalars (just (dtype2 )))
57
63
58
64
def sanity_check (x1 , x2 ):
59
- from .test_type_promotion import promotion_table , dtype_mapping
60
- t1 = [i for i in dtype_mapping if dtype_mapping [i ] == x1 .dtype ][0 ]
61
- t2 = [i for i in dtype_mapping if dtype_mapping [i ] == x2 .dtype ][0 ]
62
-
63
- if (t1 , t2 ) not in promotion_table :
65
+ try :
66
+ promote_dtypes (x1 .dtype , x2 .dtype )
67
+ except ValueError :
64
68
raise RuntimeError ("Error in test generation (probably a bug in the test suite" )
65
69
66
70
@given (numeric_scalars )
@@ -353,7 +357,49 @@ def test_divide(args):
353
357
def test_equal (args ):
354
358
x1 , x2 = args
355
359
sanity_check (x1 , x2 )
356
- # a = _array_module.equal(x1, x2)
360
+ a = _array_module .equal (x1 , x2 )
361
+ # NOTE: assert_exactly_equal() itself uses equal(), so we must be careful
362
+ # not to use it here. Otherwise, the test would be circular and
363
+ # meaningless. Instead, we implement this by iterating every element of
364
+ # the arrays and comparing them. The logic here is also used for the tests
365
+ # for the other elementwise functions that accept any input dtype but
366
+ # always return bool (greater(), greater_equal(), less(), less_equal(),
367
+ # and not_equal()).
368
+
369
+ # First we broadcast the arrays so that they can be indexed uniformly.
370
+ # TODO: it should be possible to skip this step if we instead generate
371
+ # indices to x1 and x2 that correspond to the broadcasted shapes. This
372
+ # would avoid the dependence in this test on broadcast_to().
373
+ shape = broadcast_shapes (x1 .shape , x2 .shape )
374
+ _x1 = _array_module .broadcast_to (x1 , shape )
375
+ _x2 = _array_module .broadcast_to (x2 , shape )
376
+
377
+ # Second, manually promote the dtypes. This is important. If the internal
378
+ # type promotion in equal() is wrong, it will not be directly visible in
379
+ # the output type, but it can lead to wrong answers. For example,
380
+ # equal(array(1.0, dtype=float32), array(1.00000001, dtype=float64)) will
381
+ # be wrong if the float64 is downcast to float32. See the comment on
382
+ # test_elementwise_function_two_arg_bool_type_promotion() in
383
+ # test_type_promotion.py. The type promotion for equal() is not tested in
384
+ # that file because doing so requires doing the consistency check we do
385
+ # here.
386
+ promoted_dtype = promote_dtypes (x1 .dtype , x2 .dtype )
387
+ _x1 = _array_module .asarray (_x1 , dtype = promoted_dtype )
388
+ _x2 = _array_module .asarray (_x2 , dtype = promoted_dtype )
389
+
390
+ if is_integer_dtype (promoted_dtype ):
391
+ scalar_func = int
392
+ elif is_float_dtype (promoted_dtype ):
393
+ scalar_func = float
394
+ else :
395
+ scalar_func = bool
396
+ for idx in ndindex (shape ):
397
+ # Sanity check
398
+ aidx = a [idx ]
399
+ x1idx = _x1 [idx ]
400
+ x2idx = _x2 [idx ]
401
+ assert aidx .shape == x1idx .shape == x2idx .shape
402
+ assert bool (aidx ) == (scalar_func (x1idx ) == scalar_func (x2idx ))
357
403
358
404
@given (floating_scalars )
359
405
def test_exp (x ):
0 commit comments