@@ -697,6 +697,36 @@ def test_groupby_cum_skipna(op, skipna, input, exp):
697
697
tm .assert_series_equal (expected , result )
698
698
699
699
700
@pytest.fixture
def frame():
    """Build a 10-row DataFrame with one column per dtype used by the transform tests.

    Columns: ``float`` (standard-normal draws), ``float_missing`` (same values
    with rows 2-6 set to NaN), ``int``, ``datetime``, ``timedelta``,
    ``string``, ``string_missing`` (position 5 replaced by NaN) and ``cat``
    (a Categorical built from the same strings as ``string``).
    """
    # Seeded generator: the data is identical on every run, so a failure can
    # be reproduced exactly.
    floating = Series(np.random.default_rng(2).standard_normal(10))
    floating_missing = floating.copy()
    floating_missing.iloc[2:7] = np.nan

    strings = list("abcde") * 2
    # Slice-copy so that inserting the NaN does not touch ``strings``.
    strings_missing = strings[:]
    strings_missing[5] = np.nan

    return DataFrame(
        {
            "float": floating,
            "float_missing": floating_missing,
            "int": [1, 1, 1, 1, 2] * 2,
            "datetime": date_range("1990-1-1", periods=10),
            "timedelta": pd.timedelta_range(1, freq="s", periods=10),
            "string": strings,
            "string_missing": strings_missing,
            "cat": Categorical(strings),
        },
    )
722
+
723
+
724
@pytest.fixture
def frame_mi(frame):
    """Return the ``frame`` fixture's data indexed by a 5 x 2 product MultiIndex."""
    # Re-index a copy rather than the incoming object, so a test that requests
    # both ``frame`` and ``frame_mi`` still sees ``frame`` with its original index.
    df = frame.copy()
    df.index = MultiIndex.from_product([range(5), range(2)])
    return df
728
+
729
+
700
730
@pytest .mark .slow
701
731
@pytest .mark .parametrize (
702
732
"op, args, targop" ,
@@ -707,100 +737,110 @@ def test_groupby_cum_skipna(op, skipna, input, exp):
707
737
("shift" , (1 ,), lambda x : x .shift ()),
708
738
],
709
739
)
710
@pytest.mark.parametrize("df_fix", ["frame", "frame_mi"])
@pytest.mark.parametrize(
    "gb_target",
    [
        # Seeded: this expression runs when the decorator is evaluated, so an
        # unseeded draw would give different labels on every run / worker.
        {"by": np.random.default_rng(2).integers(0, 50, size=10).astype(float)},
        {"level": 0},
        {"by": "string"},
        # {"by": 'string_missing'}]:
        # {"by": ['int','string']}]:
        # TODO: remove or enable commented-out code
    ],
)
def test_cython_transform_frame(request, op, args, targop, df_fix, gb_target):
    """Check frame-level ``transform(op)`` and ``gb.<op>()`` against ``apply(targop)``.

    ``op``, ``args`` and ``targop`` are supplied by the parametrization stacked
    above this block; ``df_fix`` names the fixture providing the frame and
    ``gb_target`` holds the keyword arguments forwarded to ``DataFrame.groupby``.
    """
    df = request.getfixturevalue(df_fix)
    gb = df.groupby(group_keys=False, **gb_target)

    # NOTE: ``gb_target`` is a dict whose keys are "by" or "level", so the
    # ``"int" not in gb_target`` test holds for every target listed above.
    if op != "shift" and "int" not in gb_target:
        # numeric apply fastpath promotes dtype so have
        # to apply separately and concat
        i = gb[["int"]].apply(targop)
        f = gb[["float", "float_missing"]].apply(targop)
        expected = concat([f, i], axis=1)
    else:
        expected = gb.apply(targop)

    expected = expected.sort_index(axis=1)
    if op == "shift":
        for col in ("string_missing", "string"):
            expected[col] = expected[col].fillna(np.nan, downcast=False)

    # Both spellings of the operation must agree with the apply-based result.
    result = gb[expected.columns].transform(op, *args).sort_index(axis=1)
    tm.assert_frame_equal(result, expected)
    result = getattr(gb[expected.columns], op)(*args).sort_index(axis=1)
    tm.assert_frame_equal(result, expected)
776
+
777
+
778
@pytest.mark.slow
@pytest.mark.parametrize(
    "op, args, targop",
    [
        ("cumprod", (), lambda x: x.cumprod()),
        ("cumsum", (), lambda x: x.cumsum()),
        ("shift", (-1,), lambda x: x.shift(-1)),
        ("shift", (1,), lambda x: x.shift()),
    ],
)
@pytest.mark.parametrize("df_fix", ["frame", "frame_mi"])
@pytest.mark.parametrize(
    "gb_target",
    [
        # Seeded: this expression runs when the decorator is evaluated, so an
        # unseeded draw would give different labels on every run / worker.
        {"by": np.random.default_rng(2).integers(0, 50, size=10).astype(float)},
        {"level": 0},
        {"by": "string"},
        # {"by": 'string_missing'}]:
        # {"by": ['int','string']}]:
        # TODO: remove or enable commented-out code
    ],
)
@pytest.mark.parametrize(
    "column",
    [
        "float",
        "float_missing",
        "int",
        "datetime",
        "timedelta",
        "string",
        "string_missing",
    ],
)
def test_cython_transform_frame_column(
    request, op, args, targop, df_fix, gb_target, column
):
    """Check ``transform(op)`` and ``gb[column].<op>()`` on a single column.

    For the column/op pairs singled out below, both spellings are expected to
    raise ``TypeError``; for all others they must equal ``apply(targop)``.
    """
    df = request.getfixturevalue(df_fix)
    gb = df.groupby(group_keys=False, **gb_target)

    if (
        column not in ["float", "int", "float_missing"]
        and op != "shift"
        and not (column == "timedelta" and op == "cumsum")
    ):
        msg = "|".join(
            [
                "does not support .* operations",
                ".* is not supported for object dtype",
                "is not implemented for this dtype",
            ]
        )
        with pytest.raises(TypeError, match=msg):
            gb[column].transform(op)
        with pytest.raises(TypeError, match=msg):
            getattr(gb[column], op)()
        return

    expected = gb[column].apply(targop)
    expected.name = column
    if column in ["string_missing", "string"]:
        expected = expected.fillna(np.nan, downcast=False)

    res = gb[column].transform(op, *args)
    tm.assert_series_equal(expected, res)
    res2 = getattr(gb[column], op)(*args)
    tm.assert_series_equal(expected, res2)
804
844
805
845
806
846
def test_transform_with_non_scalar_group ():
0 commit comments