@@ -803,39 +803,29 @@ def test_transform_with_non_scalar_group():
803
803
804
804
805
805
@pytest .mark .parametrize (
806
- "cols,exp,comp_func " ,
806
+ "cols,expected " ,
807
807
[
808
- ("a" , Series ([1 , 1 , 1 ], name = "a" ), tm . assert_series_equal ),
808
+ ("a" , Series ([1 , 1 , 1 ], name = "a" )),
809
809
(
810
810
["a" , "c" ],
811
811
DataFrame ({"a" : [1 , 1 , 1 ], "c" : [1 , 1 , 1 ]}),
812
- tm .assert_frame_equal ,
813
812
),
814
813
],
815
814
)
816
815
@pytest .mark .parametrize ("agg_func" , ["count" , "rank" , "size" ])
817
- def test_transform_numeric_ret (cols , exp , comp_func , agg_func , request ):
818
- if agg_func == "size" and isinstance (cols , list ):
819
- # https://github.com/pytest-dev/pytest/issues/6300
820
- # workaround to xfail fixture/param permutations
821
- reason = "'size' transformation not supported with NDFrameGroupy"
822
- request .node .add_marker (pytest .mark .xfail (reason = reason ))
823
-
824
- # GH 19200
816
+ def test_transform_numeric_ret (cols , expected , agg_func ):
817
+ # GH#19200 and GH#27469
825
818
df = DataFrame (
826
819
{"a" : date_range ("2018-01-01" , periods = 3 ), "b" : range (3 ), "c" : range (7 , 10 )}
827
820
)
828
-
829
- warn = FutureWarning
830
- if isinstance (exp , Series ) or agg_func != "size" :
831
- warn = None
832
- with tm .assert_produces_warning (warn , match = "Dropping invalid columns" ):
833
- result = df .groupby ("b" )[cols ].transform (agg_func )
821
+ result = df .groupby ("b" )[cols ].transform (agg_func )
834
822
835
823
if agg_func == "rank" :
836
- exp = exp .astype ("float" )
837
-
838
- comp_func (result , exp )
824
+ expected = expected .astype ("float" )
825
+ elif agg_func == "size" and cols == ["a" , "c" ]:
826
+ # transform("size") returns a Series
827
+ expected = expected ["a" ].rename (None )
828
+ tm .assert_equal (result , expected )
839
829
840
830
841
831
def test_transform_ffill ():
@@ -1131,27 +1121,19 @@ def test_transform_agg_by_name(request, reduction_func, obj):
1131
1121
request .node .add_marker (
1132
1122
pytest .mark .xfail (reason = "TODO: g.transform('ngroup') doesn't work" )
1133
1123
)
1134
- if func == "size" and obj .ndim == 2 : # GH#27469
1135
- request .node .add_marker (
1136
- pytest .mark .xfail (reason = "TODO: g.transform('size') doesn't work" )
1137
- )
1138
1124
if func == "corrwith" and isinstance (obj , Series ): # GH#32293
1139
1125
request .node .add_marker (
1140
1126
pytest .mark .xfail (reason = "TODO: implement SeriesGroupBy.corrwith" )
1141
1127
)
1142
1128
1143
1129
args = {"nth" : [0 ], "quantile" : [0.5 ], "corrwith" : [obj ]}.get (func , [])
1144
-
1145
- warn = None
1146
- if isinstance (obj , DataFrame ) and func == "size" :
1147
- warn = FutureWarning
1148
-
1149
- with tm .assert_produces_warning (warn , match = "Dropping invalid columns" ):
1150
- result = g .transform (func , * args )
1130
+ result = g .transform (func , * args )
1151
1131
1152
1132
# this is the *definition* of a transformation
1153
1133
tm .assert_index_equal (result .index , obj .index )
1154
- if hasattr (obj , "columns" ):
1134
+
1135
+ if func != "size" and obj .ndim == 2 :
1136
+ # size returns a Series, unlike other transforms
1155
1137
tm .assert_index_equal (result .columns , obj .columns )
1156
1138
1157
1139
# verify that values were broadcasted across each group
0 commit comments