@@ -312,6 +312,7 @@ def test_add(
312
312
reject ()
313
313
314
314
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
315
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
315
316
if not right_is_scalar :
316
317
# add is commutative
317
318
expected = func (right , left )
@@ -773,16 +774,28 @@ def test_equal(
773
774
func_name , left , right , right_is_scalar , out , res_name , xp .bool
774
775
)
775
776
assert_binary_param_shape (func_name , left , right , right_is_scalar , out , res_name )
776
- if not right_is_scalar :
777
+ if right_is_scalar :
778
+ scalar_type = dh .get_scalar_type (left .dtype )
779
+ for idx in sh .ndindex (left .shape ):
780
+ scalar_l = scalar_type (left [idx ])
781
+ expected = scalar_l == right
782
+ scalar_o = bool (out [idx ])
783
+ f_l = sh .fmt_idx (left_sym , idx )
784
+ f_o = sh .fmt_idx (res_name , idx )
785
+ assert scalar_o == expected , (
786
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } == { right } )={ expected } "
787
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
788
+ )
789
+ else :
777
790
# We manually promote the dtypes as incorrect internal type promotion
778
- # could lead to erroneous behaviour that we don't catch . For example
791
+ # could lead to false positives . For example
779
792
#
780
793
# >>> xp.equal(
781
794
# ... xp.asarray(1.0, dtype=xp.float32),
782
795
# ... xp.asarray(1.00000001, dtype=xp.float64),
783
796
# ... )
784
797
#
785
- # would incorrectly be True if float64 downcasts to float32 internally .
798
+ # would erroneously be True if float64 downcasted to float32.
786
799
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
787
800
_left = xp .astype (left , promoted_dtype )
788
801
_right = xp .astype (right , promoted_dtype )
@@ -792,11 +805,12 @@ def test_equal(
792
805
scalar_r = scalar_type (_right [r_idx ])
793
806
expected = scalar_l == scalar_r
794
807
scalar_o = bool (out [o_idx ])
808
+ f_l = sh .fmt_idx (left_sym , l_idx )
809
+ f_r = sh .fmt_idx (right_sym , r_idx )
810
+ f_o = sh .fmt_idx (res_name , o_idx )
795
811
assert scalar_o == expected , (
796
- f"out[{ o_idx } ]={ scalar_o } , but should be "
797
- f"{ left_sym } [{ l_idx } ]=={ right_sym } [{ r_idx } ]={ expected } "
798
- f"({ left_sym } [{ l_idx } ]={ scalar_l } , { right_sym } [{ r_idx } ]={ scalar_r } ) "
799
- f"[{ func_name } ()]"
812
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } == { f_r } )={ expected } "
813
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
800
814
)
801
815
802
816
@@ -1311,25 +1325,37 @@ def test_not_equal(
1311
1325
assert_binary_param_dtype (
1312
1326
func_name , left , right , right_is_scalar , out , res_name , xp .bool
1313
1327
)
1314
- if not right_is_scalar :
1315
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
1316
-
1317
- shape = broadcast_shapes (left .shape , right .shape )
1318
- ph .assert_shape (func_name , out .shape , shape )
1319
- _left = xp .broadcast_to (left , shape )
1320
- _right = xp .broadcast_to (right , shape )
1321
-
1328
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , out , res_name )
1329
+ if right_is_scalar :
1330
+ scalar_type = dh .get_scalar_type (left .dtype )
1331
+ for idx in sh .ndindex (left .shape ):
1332
+ scalar_l = scalar_type (left [idx ])
1333
+ expected = scalar_l != right
1334
+ scalar_o = bool (out [idx ])
1335
+ f_l = sh .fmt_idx (left_sym , idx )
1336
+ f_o = sh .fmt_idx (res_name , idx )
1337
+ assert scalar_o == expected , (
1338
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } != { right } )={ expected } "
1339
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
1340
+ )
1341
+ else :
1342
+ # See test_equal note
1322
1343
promoted_dtype = dh .promotion_table [left .dtype , right .dtype ]
1323
- _left = ah .asarray (_left , dtype = promoted_dtype )
1324
- _right = ah .asarray (_right , dtype = promoted_dtype )
1325
-
1344
+ _left = xp .astype (left , promoted_dtype )
1345
+ _right = xp .astype (right , promoted_dtype )
1326
1346
scalar_type = dh .get_scalar_type (promoted_dtype )
1327
- for idx in sh .ndindex (shape ):
1328
- out_idx = out [idx ]
1329
- x1_idx = _left [idx ]
1330
- x2_idx = _right [idx ]
1331
- assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1332
- assert bool (out_idx ) == (scalar_type (x1_idx ) != scalar_type (x2_idx ))
1347
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , out .shape ):
1348
+ scalar_l = scalar_type (_left [l_idx ])
1349
+ scalar_r = scalar_type (_right [r_idx ])
1350
+ expected = scalar_l != scalar_r
1351
+ scalar_o = bool (out [o_idx ])
1352
+ f_l = sh .fmt_idx (left_sym , l_idx )
1353
+ f_r = sh .fmt_idx (right_sym , r_idx )
1354
+ f_o = sh .fmt_idx (res_name , o_idx )
1355
+ assert scalar_o == expected , (
1356
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } != { f_r } )={ expected } "
1357
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1358
+ )
1333
1359
1334
1360
1335
1361
@pytest .mark .parametrize (
0 commit comments