@@ -1198,25 +1198,59 @@ def test_transform_lambda_indexing():
1198
1198
tm .assert_frame_equal (result , expected )
1199
1199
1200
1200
1201
- @pytest .mark .parametrize (
1202
- "input_df" ,
1203
- [
1204
- DataFrame (
1205
- {
1206
- "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1207
- "B" : [1 , 2 , np .nan , 3 , 3 , np .nan , 4 ],
1208
- }
1209
- ),
1210
- DataFrame (
1211
- {
1212
- "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1213
- "B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ],
1214
- }
1215
- ),
1216
- ],
1217
- )
1218
- def test_groupby_transform_fillna (input_df ):
1219
- # GH 27905
1220
- result = input_df .groupby ("A" ).transform (lambda x : x .fillna (x .mean ()))
1221
- expected = pd .DataFrame ({"B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ]})
1222
- tm .assert_frame_equal (result , expected )
1201
+ def test_transform_nan_tshift_corrwith (transformation_func ):
1202
+
1203
+ df1 = DataFrame (
1204
+ {
1205
+ "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1206
+ "B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ],
1207
+ }
1208
+ )
1209
+ g1 = df1 .groupby ("A" )
1210
+
1211
+ if transformation_func == "corrwith" :
1212
+ result = g1 .corrwith (df1 )
1213
+ expected = pd .DataFrame (dict (B = [1 , np .nan , np .nan ], A = [np .nan ] * 3 ))
1214
+ expected .index = pd .Index ([121 , 231 , 676 ], name = "A" )
1215
+ tm .assert_frame_equal (result , expected )
1216
+
1217
+ if transformation_func == "fillna" :
1218
+ df3 = df1 .copy ()
1219
+ df3 ["B" ] = [1 , np .nan , np .nan , 3 , np .nan , 3 , 4 ]
1220
+ result = df3 .groupby ("A" ).transform (lambda x : x .fillna (x .mean ()))
1221
+ expected = pd .DataFrame ({"B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ]})
1222
+ tm .assert_frame_equal (result , expected )
1223
+
1224
+ result = df3 .groupby ("A" ).transform (transformation_func , value = 1 )
1225
+ expected = pd .DataFrame ({"B" : [1.0 , 1.0 , 1.0 , 3.0 , 1.0 , 3.0 , 4.0 ]})
1226
+ tm .assert_frame_equal (result , expected )
1227
+
1228
+ if transformation_func == "tshift" :
1229
+ df2 = df1 .copy ()
1230
+ dt_periods = pd .date_range ("2013-11-03" , periods = 7 , freq = "D" )
1231
+ df2 ["C" ] = dt_periods
1232
+ result = df2 .set_index ("C" ).groupby ("A" ).tshift (2 , "D" )
1233
+ df2 ["C" ] = dt_periods + dt_periods .freq * 2
1234
+ expected = df2
1235
+ tm .assert_frame_equal (
1236
+ result .reset_index ().reindex (columns = ["A" , "B" , "C" ]), expected
1237
+ )
1238
+
1239
+
1240
+ def test_check_original_and_transformed_index (transformation_func ):
1241
+ df = DataFrame ({"A" : [0 , 0 , 0 , 1 , 1 , 1 ], "B" : [0 , 1 , 2 , 3 , 4 , 5 ]})
1242
+ g = df .groupby ("A" )
1243
+
1244
+ if transformation_func in [
1245
+ "cummax" ,
1246
+ "cummin" ,
1247
+ "cumprod" ,
1248
+ "cumsum" ,
1249
+ "diff" ,
1250
+ "ffill" ,
1251
+ "pct_change" ,
1252
+ "rank" ,
1253
+ "shift" ,
1254
+ ]:
1255
+ result = g .transform (transformation_func )
1256
+ tm .assert_index_equal (result .index , df .index )
0 commit comments