
Commit 4d7a03a

BUG: Fix some cases of groupby(...).transform with dropna=True (#46209)
1 parent bb9a985 commit 4d7a03a

File tree

6 files changed: +256 -8 lines changed

doc/source/whatsnew/v1.5.0.rst

+15
@@ -82,17 +82,32 @@ did not have the same index as the input.
 
 .. code-block:: ipython
 
+    In [3]: df.groupby('a', dropna=True).transform(lambda x: x.sum())
+    Out[3]:
+       b
+    0  5
+    1  5
+
     In [3]: df.groupby('a', dropna=True).transform(lambda x: x)
     Out[3]:
        b
     0  2
     1  3
 
+    In [3]: df.groupby('a', dropna=True).transform('sum')
+    Out[3]:
+       b
+    0  5
+    1  5
+    2  5
+
 *New behavior*:
 
 .. ipython:: python
 
+    df.groupby('a', dropna=True).transform(lambda x: x.sum())
     df.groupby('a', dropna=True).transform(lambda x: x)
+    df.groupby('a', dropna=True).transform('sum')
 
 
 .. _whatsnew_150.notable_bug_fixes.notable_bug_fix2:
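For reference, the whatsnew entry above can be reproduced with a small frame. The actual df is defined earlier in the whatsnew section; the values below are an assumption consistent with the outputs shown (group 'a' == 1 holds b = [2, 3], and the third row's key is NaN):

import numpy as np
import pandas as pd

# Assumed example frame, inferred from the outputs above (not taken from the commit)
df = pd.DataFrame({"a": [1, 1, np.nan], "b": [2, 3, 4]})

# Fixed behavior: the result keeps the input's index, and rows whose group
# key is NaN become NaN for lambdas and string aliases alike.
print(df.groupby("a", dropna=True).transform(lambda x: x.sum()))
print(df.groupby("a", dropna=True).transform("sum"))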

pandas/core/generic.py

+18 -2
@@ -3693,10 +3693,26 @@ class max_speed
 
         nv.validate_take((), kwargs)
 
+        return self._take(indices, axis)
+
+    def _take(
+        self: NDFrameT,
+        indices,
+        axis=0,
+        convert_indices: bool_t = True,
+    ) -> NDFrameT:
+        """
+        Internal version of the `take` allowing specification of additional args.
+
+        See the docstring of `take` for full explanation of the parameters.
+        """
         self._consolidate_inplace()
 
         new_data = self._mgr.take(
-            indices, axis=self._get_block_manager_axis(axis), verify=True
+            indices,
+            axis=self._get_block_manager_axis(axis),
+            verify=True,
+            convert_indices=convert_indices,
         )
         return self._constructor(new_data).__finalize__(self, method="take")
 
@@ -3708,7 +3724,7 @@ def _take_with_is_copy(self: NDFrameT, indices, axis=0) -> NDFrameT:
 
         See the docstring of `take` for full explanation of the parameters.
         """
-        result = self.take(indices=indices, axis=axis)
+        result = self._take(indices=indices, axis=axis)
         # Maybe set copy if we didn't actually change the index.
         if not result._get_axis(axis).equals(self._get_axis(axis)):
             result._set_is_copy(self)
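The public take keeps its behavior; the new internal _take is what lets groupby opt out of negative-index conversion. A quick illustration, using public API only and an illustrative frame, of why -1 needs a different meaning on the internal path:

import pandas as pd

df = pd.DataFrame({"b": [2, 3, 4]})

# Public DataFrame.take: negative positions wrap around, so -1 is the last row.
print(df.take([0, -1]))

# In the groupby transform path, an id of -1 marks a row whose group key was
# dropped (dropna=True); that row must come back as NaN rather than as a copy
# of the last row, which is what _take(..., convert_indices=False) enables.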

pandas/core/groupby/groupby.py

+10 -2
@@ -1646,7 +1646,10 @@ def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:
             out = algorithms.take_nd(result._values, ids)
             output = obj._constructor(out, index=obj.index, name=obj.name)
         else:
-            output = result.take(ids, axis=self.axis)
+            # GH#46209
+            # Don't convert indices: negative indices need to give rise
+            # to null values in the result
+            output = result._take(ids, axis=self.axis, convert_indices=False)
             output = output.set_axis(obj._get_axis(self.axis), axis=self.axis)
         return output
 
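The fast transform path reduces once per group, then broadcasts the reduced values back over the original rows via the group codes ids, where -1 marks rows whose NaN key was dropped. A minimal sketch of that broadcast semantic using the public ExtensionArray.take with made-up values (not the internal call itself):

import numpy as np
import pandas as pd

per_group = pd.array([3, 5])       # one reduced value per kept group
ids = np.array([0, 0, 1, -1])      # group code for each original row; -1 = dropped key

# allow_fill=True gives -1 the meaning "missing" instead of "last element".
print(per_group.take(ids, allow_fill=True))   # [3, 3, 5, <NA>]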

@@ -1699,9 +1702,14 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray:
         else:
             out = np.repeat(out[np.r_[run[1:], True]], rep) - out
 
+        if self.grouper.has_dropped_na:
+            out = np.where(ids == -1, np.nan, out.astype(np.float64, copy=False))
+        else:
+            out = out.astype(np.int64, copy=False)
+
         rev = np.empty(count, dtype=np.intp)
         rev[sorter] = np.arange(count, dtype=np.intp)
-        return out[rev].astype(np.int64, copy=False)
+        return out[rev]
 
 
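When null group keys were dropped, _cumcount_array now upcasts to float64 and writes NaN wherever ids == -1; otherwise it keeps int64. The effect is visible through the public cumcount; a small sketch of the intended behavior (not a test from this commit):

import numpy as np
import pandas as pd

df = pd.DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]})

# dropna=True: the NaN-keyed row belongs to no group, so its cumcount is NaN
# and the result is float64.
print(df.groupby("A", dropna=True).cumcount())

# dropna=False: every row has a group, so the result stays int64.
print(df.groupby("A", dropna=False).cumcount())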

pandas/core/internals/array_manager.py

+9 -2
@@ -640,7 +640,13 @@ def _reindex_indexer(
 
         return type(self)(new_arrays, new_axes, verify_integrity=False)
 
-    def take(self: T, indexer, axis: int = 1, verify: bool = True) -> T:
+    def take(
+        self: T,
+        indexer,
+        axis: int = 1,
+        verify: bool = True,
+        convert_indices: bool = True,
+    ) -> T:
         """
         Take items along any axis.
         """
@@ -656,7 +662,8 @@ def take(self: T, indexer, axis: int = 1, verify: bool = True) -> T:
             raise ValueError("indexer should be 1-dimensional")
 
         n = self.shape_proper[axis]
-        indexer = maybe_convert_indices(indexer, n, verify=verify)
+        if convert_indices:
+            indexer = maybe_convert_indices(indexer, n, verify=verify)
 
         new_labels = self._axes[axis].take(indexer)
         return self._reindex_indexer(

pandas/core/internals/managers.py

+11 -2
@@ -829,7 +829,13 @@ def _make_na_block(
         block_values.fill(fill_value)
         return new_block_2d(block_values, placement=placement)
 
-    def take(self: T, indexer, axis: int = 1, verify: bool = True) -> T:
+    def take(
+        self: T,
+        indexer,
+        axis: int = 1,
+        verify: bool = True,
+        convert_indices: bool = True,
+    ) -> T:
         """
         Take items along any axis.
 
@@ -838,6 +844,8 @@ def take(self: T, indexer, axis: int = 1, verify: bool = True) -> T:
         verify : bool, default True
             Check that all entries are between 0 and len(self) - 1, inclusive.
             Pass verify=False if this check has been done by the caller.
+        convert_indices : bool, default True
+            Whether to attempt to convert indices to positive values.
 
         Returns
         -------
@@ -851,7 +859,8 @@ def take(self: T, indexer, axis: int = 1, verify: bool = True) -> T:
         )
 
         n = self.shape[axis]
-        indexer = maybe_convert_indices(indexer, n, verify=verify)
+        if convert_indices:
+            indexer = maybe_convert_indices(indexer, n, verify=verify)
 
         new_labels = self.axes[axis].take(indexer)
         return self.reindex_indexer(
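Both BlockManager.take and ArrayManager.take now accept convert_indices so callers can skip the negative-index normalization. A rough standalone sketch of what that normalization step does; normalize_indices is a hypothetical stand-in for pandas' internal maybe_convert_indices, not its actual implementation:

import numpy as np

def normalize_indices(indices, n: int) -> np.ndarray:
    # Wrap negative positions to n + i and reject anything still out of bounds.
    idx = np.asarray(indices, dtype=np.intp)
    idx = np.where(idx < 0, idx + n, idx)
    if ((idx < 0) | (idx >= n)).any():
        raise IndexError("indices are out-of-bounds")
    return idx

print(normalize_indices([0, -1], 3))  # [0 2]: -1 becomes the last position

Skipping this step (convert_indices=False) leaves -1 in the indexer, so a later fill-aware take can turn those positions into missing values instead of the last row.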

pandas/tests/groupby/transform/test_transform.py

+193
@@ -1308,3 +1308,196 @@ def test_null_group_lambda_self(sort, dropna):
     gb = df.groupby("A", dropna=dropna, sort=sort)
     result = gb.transform(lambda x: x)
     tm.assert_frame_equal(result, expected)
+
+
+def test_null_group_str_reducer(request, dropna, reduction_func):
+    # GH 17093
+    if reduction_func in ("corrwith", "ngroup"):
+        msg = "incorrectly raises"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+    index = [1, 2, 3, 4]  # test transform preserves non-standard index
+    df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index)
+    gb = df.groupby("A", dropna=dropna)
+
+    if reduction_func == "corrwith":
+        args = (df["B"],)
+    elif reduction_func == "nth":
+        args = (0,)
+    else:
+        args = ()
+
+    # Manually handle reducers that don't fit the generic pattern
+    # Set expected with dropna=False, then replace if necessary
+    if reduction_func == "first":
+        expected = DataFrame({"B": [1, 1, 2, 2]}, index=index)
+    elif reduction_func == "last":
+        expected = DataFrame({"B": [2, 2, 3, 3]}, index=index)
+    elif reduction_func == "nth":
+        expected = DataFrame({"B": [1, 1, 2, 2]}, index=index)
+    elif reduction_func == "size":
+        expected = Series([2, 2, 2, 2], index=index)
+    elif reduction_func == "corrwith":
+        expected = DataFrame({"B": [1.0, 1.0, 1.0, 1.0]}, index=index)
+    else:
+        expected_gb = df.groupby("A", dropna=False)
+        buffer = []
+        for idx, group in expected_gb:
+            res = getattr(group["B"], reduction_func)()
+            buffer.append(Series(res, index=group.index))
+        expected = concat(buffer).to_frame("B")
+    if dropna:
+        dtype = object if reduction_func in ("any", "all") else float
+        expected = expected.astype(dtype)
+        if expected.ndim == 2:
+            expected.iloc[[2, 3], 0] = np.nan
+        else:
+            expected.iloc[[2, 3]] = np.nan
+
+    result = gb.transform(reduction_func, *args)
+    tm.assert_equal(result, expected)
+
+
+def test_null_group_str_transformer(
+    request, using_array_manager, dropna, transformation_func
+):
+    # GH 17093
+    xfails_block = (
+        "cummax",
+        "cummin",
+        "cumsum",
+        "fillna",
+        "rank",
+        "backfill",
+        "ffill",
+        "bfill",
+        "pad",
+    )
+    xfails_array = ("cummax", "cummin", "cumsum", "fillna", "rank")
+    if transformation_func == "tshift":
+        msg = "tshift requires timeseries"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+    elif dropna and (
+        (not using_array_manager and transformation_func in xfails_block)
+        or (using_array_manager and transformation_func in xfails_array)
+    ):
+        msg = "produces incorrect results when nans are present"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+    args = (0,) if transformation_func == "fillna" else ()
+    df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3])
+    gb = df.groupby("A", dropna=dropna)
+
+    buffer = []
+    for k, (idx, group) in enumerate(gb):
+        if transformation_func == "cumcount":
+            # DataFrame has no cumcount method
+            res = DataFrame({"B": range(len(group))}, index=group.index)
+        elif transformation_func == "ngroup":
+            res = DataFrame(len(group) * [k], index=group.index, columns=["B"])
+        else:
+            res = getattr(group[["B"]], transformation_func)(*args)
+        buffer.append(res)
+    if dropna:
+        dtype = object if transformation_func in ("any", "all") else None
+        buffer.append(DataFrame([[np.nan]], index=[3], dtype=dtype, columns=["B"]))
+    expected = concat(buffer)
+
+    if transformation_func in ("cumcount", "ngroup"):
+        # ngroup/cumcount always returns a Series as it counts the groups, not values
+        expected = expected["B"].rename(None)
+
+    warn = FutureWarning if transformation_func in ("backfill", "pad") else None
+    msg = f"{transformation_func} is deprecated"
+    with tm.assert_produces_warning(warn, match=msg):
+        result = gb.transform(transformation_func, *args)
+
+    tm.assert_equal(result, expected)
+
+
+def test_null_group_str_reducer_series(request, dropna, reduction_func):
+    # GH 17093
+    if reduction_func == "corrwith":
+        msg = "corrwith not implemented for SeriesGroupBy"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+
+    if reduction_func == "ngroup":
+        msg = "ngroup fails"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+
+    # GH 17093
+    index = [1, 2, 3, 4]  # test transform preserves non-standard index
+    ser = Series([1, 2, 2, 3], index=index)
+    gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna)
+
+    if reduction_func == "corrwith":
+        args = (ser,)
+    elif reduction_func == "nth":
+        args = (0,)
+    else:
+        args = ()
+
+    # Manually handle reducers that don't fit the generic pattern
+    # Set expected with dropna=False, then replace if necessary
+    if reduction_func == "first":
+        expected = Series([1, 1, 2, 2], index=index)
+    elif reduction_func == "last":
+        expected = Series([2, 2, 3, 3], index=index)
+    elif reduction_func == "nth":
+        expected = Series([1, 1, 2, 2], index=index)
+    elif reduction_func == "size":
+        expected = Series([2, 2, 2, 2], index=index)
+    elif reduction_func == "corrwith":
+        expected = Series([1, 1, 2, 2], index=index)
+    else:
+        expected_gb = ser.groupby([1, 1, np.nan, np.nan], dropna=False)
+        buffer = []
+        for idx, group in expected_gb:
+            res = getattr(group, reduction_func)()
+            buffer.append(Series(res, index=group.index))
+        expected = concat(buffer)
+    if dropna:
+        dtype = object if reduction_func in ("any", "all") else float
+        expected = expected.astype(dtype)
+        expected.iloc[[2, 3]] = np.nan
+
+    result = gb.transform(reduction_func, *args)
+    tm.assert_series_equal(result, expected)
+
+
+def test_null_group_str_transformer_series(request, dropna, transformation_func):
+    # GH 17093
+    if transformation_func == "tshift":
+        msg = "tshift requires timeseries"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+    elif dropna and transformation_func in (
+        "cummax",
+        "cummin",
+        "cumsum",
+        "fillna",
+        "rank",
+    ):
+        msg = "produces incorrect results when nans are present"
+        request.node.add_marker(pytest.mark.xfail(reason=msg))
+    args = (0,) if transformation_func == "fillna" else ()
+    ser = Series([1, 2, 2], index=[1, 2, 3])
+    gb = ser.groupby([1, 1, np.nan], dropna=dropna)
+
+    buffer = []
+    for k, (idx, group) in enumerate(gb):
+        if transformation_func == "cumcount":
+            # Series has no cumcount method
+            res = Series(range(len(group)), index=group.index)
+        elif transformation_func == "ngroup":
+            res = Series(k, index=group.index)
+        else:
+            res = getattr(group, transformation_func)(*args)
+        buffer.append(res)
+    if dropna:
+        dtype = object if transformation_func in ("any", "all") else None
+        buffer.append(Series([np.nan], index=[3], dtype=dtype))
+    expected = concat(buffer)
+
+    warn = FutureWarning if transformation_func in ("backfill", "pad") else None
+    msg = f"{transformation_func} is deprecated"
+    with tm.assert_produces_warning(warn, match=msg):
+        result = gb.transform(transformation_func, *args)
+    tm.assert_equal(result, expected)
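The expected values in these tests follow one pattern: reduce or transform per group with dropna=False, stitch the pieces back together with concat, then null out the rows whose key is NaN when dropna=True. A standalone sketch of that pattern for the 'sum' reducer; under the fixed behavior the final assertion should hold:

import numpy as np
import pandas as pd

df = pd.DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=[1, 2, 3, 4])

# Build expected: broadcast each dropna=False group sum over that group's rows ...
buffer = []
for _, group in df.groupby("A", dropna=False):
    buffer.append(pd.Series(group["B"].sum(), index=group.index))
expected = pd.concat(buffer).to_frame("B").astype(float)
# ... then null out the rows whose group key is NaN (positions 2 and 3).
expected.iloc[[2, 3], 0] = np.nan

result = df.groupby("A", dropna=True).transform("sum")
pd.testing.assert_frame_equal(result, expected)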
