@@ -544,231 +544,120 @@ def _assert_where_conversion(self, original, cond, values,
544
544
res = target .where (cond , values )
545
545
self ._assert (res , expected , expected_dtype )
546
546
547
- def _where_object_common (self , klass ):
547
+ @pytest .mark .parametrize ("klass" , [pd .Series , pd .Index ])
548
+ @pytest .mark .parametrize ("fill_val,exp_dtype" , [
549
+ (1 , np .object ),
550
+ (1.1 , np .object ),
551
+ (1 + 1j , np .object ),
552
+ (True , np .object )])
553
+ def test_where_object (self , klass , fill_val ,exp_dtype ):
548
554
obj = klass (list ('abcd' ))
549
555
assert obj .dtype == np .object
550
556
cond = klass ([True , False , True , False ])
551
557
552
- # object + int -> object
553
- exp = klass (['a' , 1 , 'c' , 1 ])
554
- self ._assert_where_conversion (obj , cond , 1 , exp , np .object )
555
-
556
- values = klass ([5 , 6 , 7 , 8 ])
557
- exp = klass (['a' , 6 , 'c' , 8 ])
558
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
559
-
560
- # object + float -> object
561
- exp = klass (['a' , 1.1 , 'c' , 1.1 ])
562
- self ._assert_where_conversion (obj , cond , 1.1 , exp , np .object )
563
-
564
- values = klass ([5.5 , 6.6 , 7.7 , 8.8 ])
565
- exp = klass (['a' , 6.6 , 'c' , 8.8 ])
566
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
567
-
568
- # object + complex -> object
569
- exp = klass (['a' , 1 + 1j , 'c' , 1 + 1j ])
570
- self ._assert_where_conversion (obj , cond , 1 + 1j , exp , np .object )
571
-
572
- values = klass ([5 + 5j , 6 + 6j , 7 + 7j , 8 + 8j ])
573
- exp = klass (['a' , 6 + 6j , 'c' , 8 + 8j ])
574
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
575
-
576
- if klass is pd .Series :
577
- exp = klass (['a' , 1 , 'c' , 1 ])
578
- self ._assert_where_conversion (obj , cond , True , exp , np .object )
558
+ if fill_val is True and klass is pd .Series :
559
+ ret_val = 1
560
+ else :
561
+ ret_val = fill_val
579
562
580
- values = klass ([True , False , True , True ])
581
- exp = klass (['a' , 0 , 'c' , 1 ])
582
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
583
- elif klass is pd .Index :
584
- # object + bool -> object
585
- exp = klass (['a' , True , 'c' , True ])
586
- self ._assert_where_conversion (obj , cond , True , exp , np .object )
563
+ exp = klass (['a' , ret_val , 'c' , ret_val ])
564
+ self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
587
565
566
+ if fill_val is True :
588
567
values = klass ([True , False , True , True ])
589
- exp = klass (['a' , False , 'c' , True ])
590
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
591
568
else :
592
- NotImplementedError
569
+ values = klass (fill_val * x for x in [5 , 6 , 7 , 8 ])
570
+
571
+ exp = klass (['a' , values [1 ], 'c' , values [3 ]])
572
+ self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
593
573
594
- def test_where_series_object (self ):
595
- self ._where_object_common (pd .Series )
596
-
597
- def test_where_index_object (self ):
598
- self ._where_object_common (pd .Index )
599
-
600
- def _where_int64_common (self , klass ):
574
+ @pytest .mark .parametrize ("klass" , [pd .Series , pd .Index ])
575
+ @pytest .mark .parametrize ("fill_val,exp_dtype" , [
576
+ (1 , np .int64 ),
577
+ (1.1 , np .float64 ),
578
+ (1 + 1j , np .complex128 ),
579
+ (True , np .object )])
580
+ def test_where_int64 (self , klass , fill_val , exp_dtype ):
581
+ if klass is pd .Index and exp_dtype is np .complex128 :
582
+ pytest .skip ("Complex Index not supported" )
601
583
obj = klass ([1 , 2 , 3 , 4 ])
602
584
assert obj .dtype == np .int64
603
585
cond = klass ([True , False , True , False ])
604
586
605
- # int + int -> int
606
- exp = klass ([1 , 1 , 3 , 1 ])
607
- self ._assert_where_conversion (obj , cond , 1 , exp , np .int64 )
608
-
609
- values = klass ([5 , 6 , 7 , 8 ])
610
- exp = klass ([1 , 6 , 3 , 8 ])
611
- self ._assert_where_conversion (obj , cond , values , exp , np .int64 )
612
-
613
- # int + float -> float
614
- exp = klass ([1 , 1.1 , 3 , 1.1 ])
615
- self ._assert_where_conversion (obj , cond , 1.1 , exp , np .float64 )
616
-
617
- values = klass ([5.5 , 6.6 , 7.7 , 8.8 ])
618
- exp = klass ([1 , 6.6 , 3 , 8.8 ])
619
- self ._assert_where_conversion (obj , cond , values , exp , np .float64 )
620
-
621
- # int + complex -> complex
622
- if klass is pd .Series :
623
- exp = klass ([1 , 1 + 1j , 3 , 1 + 1j ])
624
- self ._assert_where_conversion (obj , cond , 1 + 1j , exp ,
625
- np .complex128 )
626
-
627
- values = klass ([5 + 5j , 6 + 6j , 7 + 7j , 8 + 8j ])
628
- exp = klass ([1 , 6 + 6j , 3 , 8 + 8j ])
629
- self ._assert_where_conversion (obj , cond , values , exp ,
630
- np .complex128 )
631
-
632
- # int + bool -> object
633
- exp = klass ([1 , True , 3 , True ])
634
- self ._assert_where_conversion (obj , cond , True , exp , np .object )
587
+ exp = klass ([1 , fill_val , 3 , fill_val ])
588
+ self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
635
589
636
- values = klass ([True , False , True , True ])
637
- exp = klass ([1 , False , 3 , True ])
638
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
639
-
640
- def test_where_series_int64 (self ):
641
- self ._where_int64_common (pd .Series )
642
-
643
- def test_where_index_int64 (self ):
644
- self ._where_int64_common (pd .Index )
590
+ if fill_val is True :
591
+ values = klass ([True , False , True , True ])
592
+ else :
593
+ values = klass (x * fill_val for x in [5 , 6 , 7 , 8 ])
594
+ exp = klass ([1 , values [1 ], 3 , values [3 ]])
595
+ self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
645
596
646
- def _where_float64_common (self , klass ):
597
+ @pytest .mark .parametrize ("klass" , [pd .Series , pd .Index ])
598
+ @pytest .mark .parametrize ("fill_val, exp_dtype" , [
599
+ (1 , np .float64 ),
600
+ (1.1 , np .float64 ),
601
+ (1 + 1j , np .complex128 ),
602
+ (True , np .object )])
603
+ def test_where_float64 (self , klass , fill_val , exp_dtype ):
604
+ if klass is pd .Index and exp_dtype is np .complex128 :
605
+ pytest .skip ("Complex Index not supported" )
647
606
obj = klass ([1.1 , 2.2 , 3.3 , 4.4 ])
648
607
assert obj .dtype == np .float64
649
608
cond = klass ([True , False , True , False ])
650
609
651
- # float + int -> float
652
- exp = klass ([1.1 , 1.0 , 3.3 , 1.0 ])
653
- self ._assert_where_conversion (obj , cond , 1 , exp , np .float64 )
610
+ exp = klass ([1.1 , fill_val , 3.3 , fill_val ])
611
+ self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
654
612
655
- values = klass ([5 , 6 , 7 , 8 ])
656
- exp = klass ([1.1 , 6.0 , 3.3 , 8.0 ])
657
- self ._assert_where_conversion (obj , cond , values , exp , np .float64 )
658
-
659
- # float + float -> float
660
- exp = klass ([1.1 , 1.1 , 3.3 , 1.1 ])
661
- self ._assert_where_conversion (obj , cond , 1.1 , exp , np .float64 )
662
-
663
- values = klass ([5.5 , 6.6 , 7.7 , 8.8 ])
664
- exp = klass ([1.1 , 6.6 , 3.3 , 8.8 ])
665
- self ._assert_where_conversion (obj , cond , values , exp , np .float64 )
666
-
667
- # float + complex -> complex
668
- if klass is pd .Series :
669
- exp = klass ([1.1 , 1 + 1j , 3.3 , 1 + 1j ])
670
- self ._assert_where_conversion (obj , cond , 1 + 1j , exp ,
671
- np .complex128 )
672
-
673
- values = klass ([5 + 5j , 6 + 6j , 7 + 7j , 8 + 8j ])
674
- exp = klass ([1.1 , 6 + 6j , 3.3 , 8 + 8j ])
675
- self ._assert_where_conversion (obj , cond , values , exp ,
676
- np .complex128 )
677
-
678
- # float + bool -> object
679
- exp = klass ([1.1 , True , 3.3 , True ])
680
- self ._assert_where_conversion (obj , cond , True , exp , np .object )
681
-
682
- values = klass ([True , False , True , True ])
683
- exp = klass ([1.1 , False , 3.3 , True ])
684
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
685
-
686
- def test_where_series_float64 (self ):
687
- self ._where_float64_common (pd .Series )
688
-
689
- def test_where_index_float64 (self ):
690
- self ._where_float64_common (pd .Index )
613
+ if fill_val is True :
614
+ values = klass ([True , False , True , True ])
615
+ else :
616
+ values = klass (x * fill_val for x in [5 , 6 , 7 , 8 ])
617
+ exp = klass ([1.1 , values [1 ], 3.3 , values [3 ]])
618
+ self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
691
619
692
- def test_where_series_complex128 (self ):
620
+ @pytest .mark .parametrize ("fill_val,exp_dtype" , [
621
+ (1 , np .complex128 ),
622
+ (1.1 , np .complex128 ),
623
+ (1 + 1j , np .complex128 ),
624
+ (True , np .object )])
625
+ def test_where_series_complex128 (self , fill_val , exp_dtype ):
693
626
obj = pd .Series ([1 + 1j , 2 + 2j , 3 + 3j , 4 + 4j ])
694
627
assert obj .dtype == np .complex128
695
628
cond = pd .Series ([True , False , True , False ])
696
629
697
630
# complex + int -> complex
698
- exp = pd .Series ([1 + 1j , 1 , 3 + 3j , 1 ])
699
- self ._assert_where_conversion (obj , cond , 1 , exp , np .complex128 )
700
-
701
- values = pd .Series ([5 , 6 , 7 , 8 ])
702
- exp = pd .Series ([1 + 1j , 6.0 , 3 + 3j , 8.0 ])
703
- self ._assert_where_conversion (obj , cond , values , exp , np .complex128 )
704
-
705
- # complex + float -> complex
706
- exp = pd .Series ([1 + 1j , 1.1 , 3 + 3j , 1.1 ])
707
- self ._assert_where_conversion (obj , cond , 1.1 , exp , np .complex128 )
708
-
709
- values = pd .Series ([5.5 , 6.6 , 7.7 , 8.8 ])
710
- exp = pd .Series ([1 + 1j , 6.6 , 3 + 3j , 8.8 ])
711
- self ._assert_where_conversion (obj , cond , values , exp , np .complex128 )
712
-
713
- # complex + complex -> complex
714
- exp = pd .Series ([1 + 1j , 1 + 1j , 3 + 3j , 1 + 1j ])
715
- self ._assert_where_conversion (obj , cond , 1 + 1j , exp , np .complex128 )
716
-
717
- values = pd .Series ([5 + 5j , 6 + 6j , 7 + 7j , 8 + 8j ])
718
- exp = pd .Series ([1 + 1j , 6 + 6j , 3 + 3j , 8 + 8j ])
719
- self ._assert_where_conversion (obj , cond , values , exp , np .complex128 )
720
-
721
- # complex + bool -> object
722
- exp = pd .Series ([1 + 1j , True , 3 + 3j , True ])
723
- self ._assert_where_conversion (obj , cond , True , exp , np .object )
724
-
725
- values = pd .Series ([True , False , True , True ])
726
- exp = pd .Series ([1 + 1j , False , 3 + 3j , True ])
727
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
631
+ exp = pd .Series ([1 + 1j , fill_val , 3 + 3j , fill_val ])
632
+ self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
728
633
729
- def test_where_index_complex128 (self ):
730
- pass
634
+ if fill_val is True :
635
+ values = pd .Series ([True , False , True , True ])
636
+ else :
637
+ values = pd .Series (x * fill_val for x in [5 , 6 , 7 , 8 ])
638
+ exp = pd .Series ([1 + 1j , values [1 ], 3 + 3j , values [3 ]])
639
+ self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
731
640
732
- def test_where_series_bool (self ):
641
+ @pytest .mark .parametrize ("fill_val,exp_dtype" , [
642
+ (1 , np .object ),
643
+ (1.1 , np .object ),
644
+ (1 + 1j , np .object ),
645
+ (True , np .bool )])
646
+ def test_where_series_bool (self , fill_val , exp_dtype ):
733
647
734
648
obj = pd .Series ([True , False , True , False ])
735
649
assert obj .dtype == np .bool
736
650
cond = pd .Series ([True , False , True , False ])
737
651
738
- # bool + int -> object
739
- exp = pd .Series ([True , 1 , True , 1 ])
740
- self ._assert_where_conversion (obj , cond , 1 , exp , np .object )
652
+ exp = pd .Series ([True , fill_val , True , fill_val ])
653
+ self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
741
654
742
- values = pd .Series ([5 , 6 , 7 , 8 ])
743
- exp = pd .Series ([True , 6 , True , 8 ])
744
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
745
-
746
- # bool + float -> object
747
- exp = pd .Series ([True , 1.1 , True , 1.1 ])
748
- self ._assert_where_conversion (obj , cond , 1.1 , exp , np .object )
749
-
750
- values = pd .Series ([5.5 , 6.6 , 7.7 , 8.8 ])
751
- exp = pd .Series ([True , 6.6 , True , 8.8 ])
752
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
753
-
754
- # bool + complex -> object
755
- exp = pd .Series ([True , 1 + 1j , True , 1 + 1j ])
756
- self ._assert_where_conversion (obj , cond , 1 + 1j , exp , np .object )
757
-
758
- values = pd .Series ([5 + 5j , 6 + 6j , 7 + 7j , 8 + 8j ])
759
- exp = pd .Series ([True , 6 + 6j , True , 8 + 8j ])
760
- self ._assert_where_conversion (obj , cond , values , exp , np .object )
761
-
762
- # bool + bool -> bool
763
- exp = pd .Series ([True , True , True , True ])
764
- self ._assert_where_conversion (obj , cond , True , exp , np .bool )
765
-
766
- values = pd .Series ([True , False , True , True ])
767
- exp = pd .Series ([True , False , True , True ])
768
- self ._assert_where_conversion (obj , cond , values , exp , np .bool )
769
-
770
- def test_where_index_bool (self ):
771
- pass
655
+ if fill_val is True :
656
+ values = pd .Series ([True , False , True , True ])
657
+ else :
658
+ values = pd .Series (x * fill_val for x in [5 , 6 , 7 , 8 ])
659
+ exp = pd .Series ([True , values [1 ], True , values [3 ]])
660
+ self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
772
661
773
662
def test_where_series_datetime64 (self ):
774
663
obj = pd .Series ([pd .Timestamp ('2011-01-01' ),
0 commit comments