@@ -437,31 +437,52 @@ def test_bitwise_and(
437
437
res = func (left , right )
438
438
439
439
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
440
- if not right_is_scalar :
441
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
442
- shape = broadcast_shapes (left .shape , right .shape )
443
- ph .assert_shape (func_name , res .shape , shape , repr_name = f"{ res_name } .shape" )
444
- _left = xp .broadcast_to (left , shape )
445
- _right = xp .broadcast_to (right , shape )
446
-
447
- # Compare against the Python & operator.
448
- if res .dtype == xp .bool :
449
- for idx in sh .ndindex (res .shape ):
450
- s_left = bool (_left [idx ])
451
- s_right = bool (_right [idx ])
452
- s_res = bool (res [idx ])
453
- assert (s_left and s_right ) == s_res
454
- else :
455
- for idx in sh .ndindex (res .shape ):
456
- s_left = int (_left [idx ])
457
- s_right = int (_right [idx ])
458
- s_res = int (res [idx ])
459
- s_and = ah .int_to_dtype (
460
- s_left & s_right ,
440
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
441
+ scalar_type = dh .get_scalar_type (res .dtype )
442
+ if right_is_scalar :
443
+ for idx in sh .ndindex (res .shape ):
444
+ scalar_l = scalar_type (left [idx ])
445
+ if res .dtype == xp .bool :
446
+ expected = scalar_l and right
447
+ else :
448
+ # for mypy
449
+ assert isinstance (scalar_l , int )
450
+ assert isinstance (right , int )
451
+ expected = ah .int_to_dtype (
452
+ scalar_l & right ,
453
+ dh .dtype_nbits [res .dtype ],
454
+ dh .dtype_signed [res .dtype ],
455
+ )
456
+ scalar_o = scalar_type (res [idx ])
457
+ f_l = sh .fmt_idx (left_sym , idx )
458
+ f_o = sh .fmt_idx (res_name , idx )
459
+ assert scalar_o == expected , (
460
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } & { right } )={ expected } "
461
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
462
+ )
463
+ else :
464
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
465
+ scalar_l = scalar_type (left [l_idx ])
466
+ scalar_r = scalar_type (right [r_idx ])
467
+ if res .dtype == xp .bool :
468
+ expected = scalar_l and scalar_r
469
+ else :
470
+ # for mypy
471
+ assert isinstance (scalar_l , int )
472
+ assert isinstance (scalar_r , int )
473
+ expected = ah .int_to_dtype (
474
+ scalar_l & scalar_r ,
461
475
dh .dtype_nbits [res .dtype ],
462
476
dh .dtype_signed [res .dtype ],
463
477
)
464
- assert s_and == s_res
478
+ scalar_o = scalar_type (res [o_idx ])
479
+ f_l = sh .fmt_idx (left_sym , l_idx )
480
+ f_r = sh .fmt_idx (right_sym , r_idx )
481
+ f_o = sh .fmt_idx (res_name , o_idx )
482
+ assert scalar_o == expected , (
483
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } & { f_r } )={ expected } "
484
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
485
+ )
465
486
466
487
467
488
@pytest .mark .parametrize (
@@ -489,25 +510,41 @@ def test_bitwise_left_shift(
489
510
res = func (left , right )
490
511
491
512
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
492
- if not right_is_scalar :
493
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
494
- shape = broadcast_shapes (left .shape , right .shape )
495
- ph .assert_shape (func_name , res .shape , shape , repr_name = f"{ res_name } .shape" )
496
- _left = xp .broadcast_to (left , shape )
497
- _right = xp .broadcast_to (right , shape )
498
-
499
- # Compare against the Python << operator.
513
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
514
+ if right_is_scalar :
500
515
for idx in sh .ndindex (res .shape ):
501
- s_left = int (_left [idx ])
502
- s_right = int (_right [idx ])
503
- s_res = int (res [idx ])
504
- s_shift = ah .int_to_dtype (
516
+ scalar_l = int (left [idx ])
517
+ expected = ah .int_to_dtype (
505
518
# We avoid shifting very large ints
506
- s_left << s_right if s_right < dh .dtype_nbits [res .dtype ] else 0 ,
519
+ scalar_l << right if right < dh .dtype_nbits [res .dtype ] else 0 ,
507
520
dh .dtype_nbits [res .dtype ],
508
521
dh .dtype_signed [res .dtype ],
509
522
)
510
- assert s_shift == s_res
523
+ scalar_o = int (res [idx ])
524
+ f_l = sh .fmt_idx (left_sym , idx )
525
+ f_o = sh .fmt_idx (res_name , idx )
526
+ assert scalar_o == expected , (
527
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } << { right } )={ expected } "
528
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
529
+ )
530
+ else :
531
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
532
+ scalar_l = int (left [l_idx ])
533
+ scalar_r = int (right [r_idx ])
534
+ expected = ah .int_to_dtype (
535
+ # We avoid shifting very large ints
536
+ scalar_l << scalar_r if scalar_r < dh .dtype_nbits [res .dtype ] else 0 ,
537
+ dh .dtype_nbits [res .dtype ],
538
+ dh .dtype_signed [res .dtype ],
539
+ )
540
+ scalar_o = int (res [o_idx ])
541
+ f_l = sh .fmt_idx (left_sym , l_idx )
542
+ f_r = sh .fmt_idx (right_sym , r_idx )
543
+ f_o = sh .fmt_idx (res_name , o_idx )
544
+ assert scalar_o == expected , (
545
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } << { f_r } )={ expected } "
546
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
547
+ )
511
548
512
549
513
550
@pytest .mark .parametrize (
@@ -522,20 +559,23 @@ def test_bitwise_invert(func_name, func, strat, data):
522
559
523
560
ph .assert_dtype (func_name , x .dtype , out .dtype )
524
561
ph .assert_shape (func_name , out .shape , x .shape )
525
- # Compare against the Python ~ operator.
526
- if out .dtype == xp .bool :
527
- for idx in sh .ndindex (out .shape ):
528
- s_x = bool (x [idx ])
529
- s_out = bool (out [idx ])
530
- assert (not s_x ) == s_out
531
- else :
532
- for idx in sh .ndindex (out .shape ):
533
- s_x = int (x [idx ])
534
- s_out = int (out [idx ])
535
- s_invert = ah .int_to_dtype (
536
- ~ s_x , dh .dtype_nbits [out .dtype ], dh .dtype_signed [out .dtype ]
562
+ for idx in sh .ndindex (out .shape ):
563
+ if out .dtype == xp .bool :
564
+ scalar_x = bool (x [idx ])
565
+ scalar_o = bool (out [idx ])
566
+ expected = not scalar_x
567
+ else :
568
+ scalar_x = int (x [idx ])
569
+ scalar_o = int (out [idx ])
570
+ expected = ah .int_to_dtype (
571
+ ~ scalar_x , dh .dtype_nbits [out .dtype ], dh .dtype_signed [out .dtype ]
537
572
)
538
- assert s_invert == s_out
573
+ f_x = sh .fmt_idx ("x" , idx )
574
+ f_o = sh .fmt_idx ("out" , idx )
575
+ assert scalar_o == expected , (
576
+ f"{ f_o } ={ scalar_o } , but should be ~{ f_x } ={ scalar_x } "
577
+ f"[{ func_name } ()]\n { f_x } ={ scalar_x } "
578
+ )
539
579
540
580
541
581
@pytest .mark .parametrize (
@@ -559,31 +599,50 @@ def test_bitwise_or(
559
599
res = func (left , right )
560
600
561
601
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
562
- if not right_is_scalar :
563
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
564
- shape = broadcast_shapes (left .shape , right .shape )
565
- ph .assert_shape (func_name , res .shape , shape , repr_name = f"{ res_name } .shape" )
566
- _left = xp .broadcast_to (left , shape )
567
- _right = xp .broadcast_to (right , shape )
568
-
569
- # Compare against the Python | operator.
570
- if res .dtype == xp .bool :
571
- for idx in sh .ndindex (res .shape ):
572
- s_left = bool (_left [idx ])
573
- s_right = bool (_right [idx ])
574
- s_res = bool (res [idx ])
575
- assert (s_left or s_right ) == s_res
576
- else :
577
- for idx in sh .ndindex (res .shape ):
578
- s_left = int (_left [idx ])
579
- s_right = int (_right [idx ])
580
- s_res = int (res [idx ])
581
- s_or = ah .int_to_dtype (
582
- s_left | s_right ,
602
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
603
+ if right_is_scalar :
604
+ for idx in sh .ndindex (res .shape ):
605
+ if res .dtype == xp .bool :
606
+ scalar_l = bool (left [idx ])
607
+ scalar_o = bool (res [idx ])
608
+ expected = scalar_l or right
609
+ else :
610
+ scalar_l = int (left [idx ])
611
+ scalar_o = int (res [idx ])
612
+ expected = ah .int_to_dtype (
613
+ scalar_l | right ,
614
+ dh .dtype_nbits [res .dtype ],
615
+ dh .dtype_signed [res .dtype ],
616
+ )
617
+ f_l = sh .fmt_idx (left_sym , idx )
618
+ f_o = sh .fmt_idx (res_name , idx )
619
+ assert scalar_o == expected , (
620
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } | { right } )={ expected } "
621
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
622
+ )
623
+ else :
624
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
625
+ if res .dtype == xp .bool :
626
+ scalar_l = bool (left [l_idx ])
627
+ scalar_r = bool (right [r_idx ])
628
+ scalar_o = bool (res [o_idx ])
629
+ expected = scalar_l or scalar_r
630
+ else :
631
+ scalar_l = int (left [l_idx ])
632
+ scalar_r = int (right [r_idx ])
633
+ scalar_o = int (res [o_idx ])
634
+ expected = ah .int_to_dtype (
635
+ scalar_l | scalar_r ,
583
636
dh .dtype_nbits [res .dtype ],
584
637
dh .dtype_signed [res .dtype ],
585
638
)
586
- assert s_or == s_res
639
+ f_l = sh .fmt_idx (left_sym , l_idx )
640
+ f_r = sh .fmt_idx (right_sym , r_idx )
641
+ f_o = sh .fmt_idx (res_name , o_idx )
642
+ assert scalar_o == expected , (
643
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } | { f_r } )={ expected } "
644
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
645
+ )
587
646
588
647
589
648
@pytest .mark .parametrize (
@@ -611,24 +670,39 @@ def test_bitwise_right_shift(
611
670
res = func (left , right )
612
671
613
672
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
614
- if not right_is_scalar :
615
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
616
- shape = broadcast_shapes (left .shape , right .shape )
617
- ph .assert_shape (
618
- "bitwise_right_shift" , res .shape , shape , repr_name = f"{ res_name } .shape"
619
- )
620
- _left = xp .broadcast_to (left , shape )
621
- _right = xp .broadcast_to (right , shape )
622
-
623
- # Compare against the Python >> operator.
673
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
674
+ if right_is_scalar :
624
675
for idx in sh .ndindex (res .shape ):
625
- s_left = int (_left [idx ])
626
- s_right = int (_right [idx ])
627
- s_res = int (res [idx ])
628
- s_shift = ah .int_to_dtype (
629
- s_left >> s_right , dh .dtype_nbits [res .dtype ], dh .dtype_signed [res .dtype ]
676
+ scalar_l = int (left [idx ])
677
+ expected = ah .int_to_dtype (
678
+ scalar_l >> right ,
679
+ dh .dtype_nbits [res .dtype ],
680
+ dh .dtype_signed [res .dtype ],
681
+ )
682
+ scalar_o = int (res [idx ])
683
+ f_l = sh .fmt_idx (left_sym , idx )
684
+ f_o = sh .fmt_idx (res_name , idx )
685
+ assert scalar_o == expected , (
686
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } >> { right } )={ expected } "
687
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
688
+ )
689
+ else :
690
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
691
+ scalar_l = int (left [l_idx ])
692
+ scalar_r = int (right [r_idx ])
693
+ expected = ah .int_to_dtype (
694
+ scalar_l >> scalar_r ,
695
+ dh .dtype_nbits [res .dtype ],
696
+ dh .dtype_signed [res .dtype ],
697
+ )
698
+ scalar_o = int (res [o_idx ])
699
+ f_l = sh .fmt_idx (left_sym , l_idx )
700
+ f_r = sh .fmt_idx (right_sym , r_idx )
701
+ f_o = sh .fmt_idx (res_name , o_idx )
702
+ assert scalar_o == expected , (
703
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } >> { f_r } )={ expected } "
704
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
630
705
)
631
- assert s_shift == s_res
632
706
633
707
634
708
@pytest .mark .parametrize (
@@ -652,31 +726,50 @@ def test_bitwise_xor(
652
726
res = func (left , right )
653
727
654
728
assert_binary_param_dtype (func_name , left , right , right_is_scalar , res , res_name )
655
- if not right_is_scalar :
656
- # TODO: generate indices without broadcasting arrays (see test_equal comment)
657
- shape = broadcast_shapes (left .shape , right .shape )
658
- ph .assert_shape (func_name , res .shape , shape , repr_name = f"{ res_name } .shape" )
659
- _left = xp .broadcast_to (left , shape )
660
- _right = xp .broadcast_to (right , shape )
661
-
662
- # Compare against the Python ^ operator.
663
- if res .dtype == xp .bool :
664
- for idx in sh .ndindex (res .shape ):
665
- s_left = bool (_left [idx ])
666
- s_right = bool (_right [idx ])
667
- s_res = bool (res [idx ])
668
- assert (s_left ^ s_right ) == s_res
669
- else :
670
- for idx in sh .ndindex (res .shape ):
671
- s_left = int (_left [idx ])
672
- s_right = int (_right [idx ])
673
- s_res = int (res [idx ])
674
- s_xor = ah .int_to_dtype (
675
- s_left ^ s_right ,
729
+ assert_binary_param_shape (func_name , left , right , right_is_scalar , res , res_name )
730
+ if right_is_scalar :
731
+ for idx in sh .ndindex (res .shape ):
732
+ if res .dtype == xp .bool :
733
+ scalar_l = bool (left [idx ])
734
+ scalar_o = bool (res [idx ])
735
+ expected = scalar_l ^ right
736
+ else :
737
+ scalar_l = int (left [idx ])
738
+ scalar_o = int (res [idx ])
739
+ expected = ah .int_to_dtype (
740
+ scalar_l ^ right ,
676
741
dh .dtype_nbits [res .dtype ],
677
742
dh .dtype_signed [res .dtype ],
678
743
)
679
- assert s_xor == s_res
744
+ f_l = sh .fmt_idx (left_sym , idx )
745
+ f_o = sh .fmt_idx (res_name , idx )
746
+ assert scalar_o == expected , (
747
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } ^ { right } )={ expected } "
748
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } "
749
+ )
750
+ else :
751
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
752
+ if res .dtype == xp .bool :
753
+ scalar_l = bool (left [l_idx ])
754
+ scalar_r = bool (right [r_idx ])
755
+ scalar_o = bool (res [o_idx ])
756
+ expected = scalar_l ^ scalar_r
757
+ else :
758
+ scalar_l = int (left [l_idx ])
759
+ scalar_r = int (right [r_idx ])
760
+ scalar_o = int (res [o_idx ])
761
+ expected = ah .int_to_dtype (
762
+ scalar_l ^ scalar_r ,
763
+ dh .dtype_nbits [res .dtype ],
764
+ dh .dtype_signed [res .dtype ],
765
+ )
766
+ f_l = sh .fmt_idx (left_sym , l_idx )
767
+ f_r = sh .fmt_idx (right_sym , r_idx )
768
+ f_o = sh .fmt_idx (res_name , o_idx )
769
+ assert scalar_o == expected , (
770
+ f"{ f_o } ={ scalar_o } , but should be ({ f_l } ^ { f_r } )={ expected } "
771
+ f"[{ func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
772
+ )
680
773
681
774
682
775
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
0 commit comments