@@ -1303,23 +1303,34 @@ def test_transform_cumcount():
1303
1303
tm .assert_series_equal (result , expected )
1304
1304
1305
1305
1306
- def test_null_group_lambda_self (sort , dropna ):
1306
+ @pytest .mark .parametrize ("keys" , [["A1" ], ["A1" , "A2" ]])
1307
+ def test_null_group_lambda_self (request , sort , dropna , keys ):
1307
1308
# GH 17093
1308
- np .random .seed (0 )
1309
- keys = np .random .randint (0 , 5 , size = 50 ).astype (float )
1310
- nulls = np .random .choice ([0 , 1 ], keys .shape ).astype (bool )
1311
- keys [nulls ] = np .nan
1312
- values = np .random .randint (0 , 5 , size = keys .shape )
1313
- df = DataFrame ({"A" : keys , "B" : values })
1309
+ if not sort and not dropna :
1310
+ msg = "GH#46584: null values get sorted when sort=False"
1311
+ request .node .add_marker (pytest .mark .xfail (reason = msg , strict = False ))
1312
+
1313
+ size = 50
1314
+ nulls1 = np .random .choice ([False , True ], size )
1315
+ nulls2 = np .random .choice ([False , True ], size )
1316
+ # Whether a group contains a null value or not
1317
+ nulls_grouper = nulls1 if len (keys ) == 1 else nulls1 | nulls2
1318
+
1319
+ a1 = np .random .randint (0 , 5 , size = size ).astype (float )
1320
+ a1 [nulls1 ] = np .nan
1321
+ a2 = np .random .randint (0 , 5 , size = size ).astype (float )
1322
+ a2 [nulls2 ] = np .nan
1323
+ values = np .random .randint (0 , 5 , size = a1 .shape )
1324
+ df = DataFrame ({"A1" : a1 , "A2" : a2 , "B" : values })
1314
1325
1315
1326
expected_values = values
1316
- if dropna and nulls .any ():
1327
+ if dropna and nulls_grouper .any ():
1317
1328
expected_values = expected_values .astype (float )
1318
- expected_values [nulls ] = np .nan
1329
+ expected_values [nulls_grouper ] = np .nan
1319
1330
expected = DataFrame (expected_values , columns = ["B" ])
1320
1331
1321
- gb = df .groupby ("A" , dropna = dropna , sort = sort )
1322
- result = gb .transform (lambda x : x )
1332
+ gb = df .groupby (keys , dropna = dropna , sort = sort )
1333
+ result = gb [[ "B" ]] .transform (lambda x : x )
1323
1334
tm .assert_frame_equal (result , expected )
1324
1335
1325
1336
0 commit comments