Skip to content

Commit 8714948

Browse files
authored
BUG: Fix remaining cases of groupby(...).transform with dropna=True (#46367)
1 parent 1d90a59 commit 8714948

File tree

8 files changed

+120
-67
lines changed

8 files changed

+120
-67
lines changed

doc/source/whatsnew/v1.5.0.rst

+21-9
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,44 @@ did not have the same index as the input.
8383

8484
.. code-block:: ipython
8585
86-
In [3]: df.groupby('a', dropna=True).transform(lambda x: x.sum())
86+
In [3]: # Value in the last row should be np.nan
87+
df.groupby('a', dropna=True).transform('sum')
8788
Out[3]:
8889
b
8990
0 5
9091
1 5
92+
2 5
9193
92-
In [3]: df.groupby('a', dropna=True).transform(lambda x: x)
94+
In [3]: # Should have one additional row with the value np.nan
95+
df.groupby('a', dropna=True).transform(lambda x: x.sum())
9396
Out[3]:
9497
b
95-
0 2
96-
1 3
98+
0 5
99+
1 5
100+
101+
In [3]: # The value in the last row is np.nan interpreted as an integer
102+
df.groupby('a', dropna=True).transform('ffill')
103+
Out[3]:
104+
b
105+
0 2
106+
1 3
107+
2 -9223372036854775808
97108
98-
In [3]: df.groupby('a', dropna=True).transform('sum')
109+
In [3]: # Should have one additional row with the value np.nan
110+
df.groupby('a', dropna=True).transform(lambda x: x)
99111
Out[3]:
100112
b
101-
0 5
102-
1 5
103-
2 5
113+
0 2
114+
1 3
104115
105116
*New behavior*:
106117

107118
.. ipython:: python
108119
120+
df.groupby('a', dropna=True).transform('sum')
109121
df.groupby('a', dropna=True).transform(lambda x: x.sum())
122+
df.groupby('a', dropna=True).transform('ffill')
110123
df.groupby('a', dropna=True).transform(lambda x: x)
111-
df.groupby('a', dropna=True).transform('sum')
112124
113125
.. _whatsnew_150.notable_bug_fixes.visualization:
114126

pandas/core/groupby/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ class OutputKey:
7070
"mean",
7171
"median",
7272
"min",
73-
"ngroup",
7473
"nth",
7574
"nunique",
7675
"prod",
@@ -113,6 +112,7 @@ def maybe_normalize_deprecated_kernels(kernel):
113112
"diff",
114113
"ffill",
115114
"fillna",
115+
"ngroup",
116116
"pad",
117117
"pct_change",
118118
"rank",

pandas/core/groupby/groupby.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class providing the base-class of operations.
6262
)
6363
from pandas.util._exceptions import find_stack_level
6464

65+
from pandas.core.dtypes.cast import ensure_dtype_can_hold_na
6566
from pandas.core.dtypes.common import (
6667
is_bool_dtype,
6768
is_datetime64_dtype,
@@ -950,7 +951,13 @@ def curried(x):
950951
if name in base.plotting_methods:
951952
return self.apply(curried)
952953

953-
return self._python_apply_general(curried, self._obj_with_exclusions)
954+
result = self._python_apply_general(curried, self._obj_with_exclusions)
955+
956+
if self.grouper.has_dropped_na and name in base.transformation_kernels:
957+
# result will have dropped rows due to nans, fill with null
958+
# and ensure the index is ordered the same as the input
959+
result = self._set_result_index_ordered(result)
960+
return result
954961

955962
wrapper.__name__ = name
956963
return wrapper
@@ -2608,7 +2615,11 @@ def blk_func(values: ArrayLike) -> ArrayLike:
26082615
# then there will be no -1s in indexer, so we can use
26092616
# the original dtype (no need to ensure_dtype_can_hold_na)
26102617
if isinstance(values, np.ndarray):
2611-
out = np.empty(values.shape, dtype=values.dtype)
2618+
dtype = values.dtype
2619+
if self.grouper.has_dropped_na:
2620+
# dropped null groups give rise to nan in the result
2621+
dtype = ensure_dtype_can_hold_na(values.dtype)
2622+
out = np.empty(values.shape, dtype=dtype)
26122623
else:
26132624
out = type(values)._empty(values.shape, dtype=values.dtype)
26142625

@@ -3114,9 +3125,16 @@ def ngroup(self, ascending: bool = True):
31143125
"""
31153126
with self._group_selection_context():
31163127
index = self._selected_obj.index
3117-
result = self._obj_1d_constructor(
3118-
self.grouper.group_info[0], index, dtype=np.int64
3119-
)
3128+
comp_ids = self.grouper.group_info[0]
3129+
3130+
dtype: type
3131+
if self.grouper.has_dropped_na:
3132+
comp_ids = np.where(comp_ids == -1, np.nan, comp_ids)
3133+
dtype = np.float64
3134+
else:
3135+
dtype = np.int64
3136+
3137+
result = self._obj_1d_constructor(comp_ids, index, dtype=dtype)
31203138
if not ascending:
31213139
result = self.ngroups - 1 - result
31223140
return result

pandas/core/groupby/ops.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,25 @@
9898
class WrappedCythonOp:
9999
"""
100100
Dispatch logic for functions defined in _libs.groupby
101+
102+
Parameters
103+
----------
104+
kind : str
105+
Whether the operation is an aggregate or transform.
106+
how : str
107+
Operation name, e.g. "mean".
108+
has_dropped_na : bool
109+
True precisely when dropna=True and the grouper contains a null value.
101110
"""
102111

103112
# Functions for which we do _not_ attempt to cast the cython result
104113
# back to the original dtype.
105114
cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
106115

107-
def __init__(self, kind: str, how: str) -> None:
116+
def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
108117
self.kind = kind
109118
self.how = how
119+
self.has_dropped_na = has_dropped_na
110120

111121
_CYTHON_FUNCTIONS = {
112122
"aggregate": {
@@ -194,7 +204,9 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
194204
values = ensure_float64(values)
195205

196206
elif values.dtype.kind in ["i", "u"]:
197-
if how in ["add", "var", "prod", "mean", "ohlc"]:
207+
if how in ["add", "var", "prod", "mean", "ohlc"] or (
208+
self.kind == "transform" and self.has_dropped_na
209+
):
198210
# result may still include NaN, so we have to cast
199211
values = ensure_float64(values)
200212

@@ -582,6 +594,10 @@ def _call_cython_op(
582594

583595
result = result.T
584596

597+
if self.how == "rank" and self.has_dropped_na:
598+
# TODO: Wouldn't need this if group_rank supported mask
599+
result = np.where(comp_ids < 0, np.nan, result)
600+
585601
if self.how not in self.cast_blocklist:
586602
# e.g. if we are int64 and need to restore to datetime64/timedelta64
587603
# "rank" is the only member of cast_blocklist we get here
@@ -959,7 +975,7 @@ def _cython_operation(
959975
"""
960976
assert kind in ["transform", "aggregate"]
961977

962-
cy_op = WrappedCythonOp(kind=kind, how=how)
978+
cy_op = WrappedCythonOp(kind=kind, how=how, has_dropped_na=self.has_dropped_na)
963979

964980
ids, _, _ = self.group_info
965981
ngroups = self.ngroups

pandas/tests/apply/test_frame_transform.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ def test_transform_bad_dtype(op, frame_or_series, request):
139139
raises=ValueError, reason="GH 40418: rank does not raise a TypeError"
140140
)
141141
)
142+
elif op == "ngroup":
143+
request.node.add_marker(
144+
pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame")
145+
)
142146

143147
obj = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms
144148
obj = tm.get_obj(obj, frame_or_series)
@@ -157,9 +161,14 @@ def test_transform_bad_dtype(op, frame_or_series, request):
157161

158162

159163
@pytest.mark.parametrize("op", frame_kernels_raise)
160-
def test_transform_partial_failure_typeerror(op):
164+
def test_transform_partial_failure_typeerror(request, op):
161165
# GH 35964
162166

167+
if op == "ngroup":
168+
request.node.add_marker(
169+
pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame")
170+
)
171+
163172
# Using object makes most transform kernels fail
164173
df = DataFrame({"A": 3 * [object], "B": [1, 2, 3]})
165174

pandas/tests/apply/test_str.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,12 @@ def test_agg_cython_table_transform_frame(df, func, expected, axis):
243243

244244

245245
@pytest.mark.parametrize("op", series_transform_kernels)
246-
def test_transform_groupby_kernel_series(string_series, op):
246+
def test_transform_groupby_kernel_series(request, string_series, op):
247247
# GH 35964
248+
if op == "ngroup":
249+
request.node.add_marker(
250+
pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame")
251+
)
248252
# TODO(2.0) Remove after pad/backfill deprecation enforced
249253
op = maybe_normalize_deprecated_kernels(op)
250254
args = [0.0] if op == "fillna" else []
@@ -255,9 +259,15 @@ def test_transform_groupby_kernel_series(string_series, op):
255259

256260

257261
@pytest.mark.parametrize("op", frame_transform_kernels)
258-
def test_transform_groupby_kernel_frame(axis, float_frame, op):
262+
def test_transform_groupby_kernel_frame(request, axis, float_frame, op):
259263
# TODO(2.0) Remove after pad/backfill deprecation enforced
260264
op = maybe_normalize_deprecated_kernels(op)
265+
266+
if op == "ngroup":
267+
request.node.add_marker(
268+
pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame")
269+
)
270+
261271
# GH 35964
262272

263273
args = [0.0] if op == "fillna" else []

pandas/tests/groupby/test_rank.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,8 @@ def test_non_unique_index():
658658
)
659659
result = df.groupby([df.index, "A"]).value.rank(ascending=True, pct=True)
660660
expected = Series(
661-
[1.0] * 4, index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4, name="value"
661+
[1.0, 1.0, 1.0, np.nan],
662+
index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4,
663+
name="value",
662664
)
663665
tm.assert_series_equal(result, expected)

pandas/tests/groupby/transform/test_transform.py

+31-45
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ def test_transform_axis_1(request, transformation_func):
168168
# TODO(2.0) Remove after pad/backfill deprecation enforced
169169
transformation_func = maybe_normalize_deprecated_kernels(transformation_func)
170170

171+
if transformation_func == "ngroup":
172+
msg = "ngroup fails with axis=1: #45986"
173+
request.node.add_marker(pytest.mark.xfail(reason=msg))
174+
171175
warn = None
172176
if transformation_func == "tshift":
173177
warn = FutureWarning
@@ -383,6 +387,15 @@ def test_transform_transformation_func(request, transformation_func):
383387
elif transformation_func == "fillna":
384388
test_op = lambda x: x.transform("fillna", value=0)
385389
mock_op = lambda x: x.fillna(value=0)
390+
elif transformation_func == "ngroup":
391+
test_op = lambda x: x.transform("ngroup")
392+
counter = -1
393+
394+
def mock_op(x):
395+
nonlocal counter
396+
counter += 1
397+
return Series(counter, index=x.index)
398+
386399
elif transformation_func == "tshift":
387400
msg = (
388401
"Current behavior of groupby.tshift is inconsistent with other "
@@ -394,10 +407,14 @@ def test_transform_transformation_func(request, transformation_func):
394407
mock_op = lambda x: getattr(x, transformation_func)()
395408

396409
result = test_op(df.groupby("A"))
397-
groups = [df[["B"]].iloc[:4], df[["B"]].iloc[4:6], df[["B"]].iloc[6:]]
398-
expected = concat([mock_op(g) for g in groups])
410+
# pass the groups in the same order as iterating `for ... in df.groupby(...)`
411+
# but reorder to match df's index since this is a transform
412+
groups = [df[["B"]].iloc[4:6], df[["B"]].iloc[6:], df[["B"]].iloc[:4]]
413+
expected = concat([mock_op(g) for g in groups]).sort_index()
414+
# sort_index does not preserve the freq
415+
expected = expected.set_axis(df.index)
399416

400-
if transformation_func == "cumcount":
417+
if transformation_func in ("cumcount", "ngroup"):
401418
tm.assert_series_equal(result, expected)
402419
else:
403420
tm.assert_frame_equal(result, expected)
@@ -1122,10 +1139,6 @@ def test_transform_agg_by_name(request, reduction_func, obj):
11221139
func = reduction_func
11231140
g = obj.groupby(np.repeat([0, 1], 3))
11241141

1125-
if func == "ngroup": # GH#27468
1126-
request.node.add_marker(
1127-
pytest.mark.xfail(reason="TODO: g.transform('ngroup') doesn't work")
1128-
)
11291142
if func == "corrwith" and isinstance(obj, Series): # GH#32293
11301143
request.node.add_marker(
11311144
pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith")
@@ -1137,8 +1150,8 @@ def test_transform_agg_by_name(request, reduction_func, obj):
11371150
# this is the *definition* of a transformation
11381151
tm.assert_index_equal(result.index, obj.index)
11391152

1140-
if func != "size" and obj.ndim == 2:
1141-
# size returns a Series, unlike other transforms
1153+
if func not in ("ngroup", "size") and obj.ndim == 2:
1154+
# size/ngroup return a Series, unlike other transforms
11421155
tm.assert_index_equal(result.columns, obj.columns)
11431156

11441157
# verify that values were broadcasted across each group
@@ -1312,7 +1325,7 @@ def test_null_group_lambda_self(sort, dropna):
13121325

13131326
def test_null_group_str_reducer(request, dropna, reduction_func):
13141327
# GH 17093
1315-
if reduction_func in ("corrwith", "ngroup"):
1328+
if reduction_func == "corrwith":
13161329
msg = "incorrectly raises"
13171330
request.node.add_marker(pytest.mark.xfail(reason=msg))
13181331
index = [1, 2, 3, 4] # test transform preserves non-standard index
@@ -1358,31 +1371,11 @@ def test_null_group_str_reducer(request, dropna, reduction_func):
13581371

13591372

13601373
@pytest.mark.filterwarnings("ignore:tshift is deprecated:FutureWarning")
1361-
def test_null_group_str_transformer(
1362-
request, using_array_manager, dropna, transformation_func
1363-
):
1374+
def test_null_group_str_transformer(request, dropna, transformation_func):
13641375
# GH 17093
1365-
xfails_block = (
1366-
"cummax",
1367-
"cummin",
1368-
"cumsum",
1369-
"fillna",
1370-
"rank",
1371-
"backfill",
1372-
"ffill",
1373-
"bfill",
1374-
"pad",
1375-
)
1376-
xfails_array = ("cummax", "cummin", "cumsum", "fillna", "rank")
13771376
if transformation_func == "tshift":
13781377
msg = "tshift requires timeseries"
13791378
request.node.add_marker(pytest.mark.xfail(reason=msg))
1380-
elif dropna and (
1381-
(not using_array_manager and transformation_func in xfails_block)
1382-
or (using_array_manager and transformation_func in xfails_array)
1383-
):
1384-
msg = "produces incorrect results when nans are present"
1385-
request.node.add_marker(pytest.mark.xfail(reason=msg))
13861379
args = (0,) if transformation_func == "fillna" else ()
13871380
df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3])
13881381
gb = df.groupby("A", dropna=dropna)
@@ -1420,10 +1413,6 @@ def test_null_group_str_reducer_series(request, dropna, reduction_func):
14201413
msg = "corrwith not implemented for SeriesGroupBy"
14211414
request.node.add_marker(pytest.mark.xfail(reason=msg))
14221415

1423-
if reduction_func == "ngroup":
1424-
msg = "ngroup fails"
1425-
request.node.add_marker(pytest.mark.xfail(reason=msg))
1426-
14271416
# GH 17093
14281417
index = [1, 2, 3, 4] # test transform preserves non-standard index
14291418
ser = Series([1, 2, 2, 3], index=index)
@@ -1470,15 +1459,6 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func)
14701459
if transformation_func == "tshift":
14711460
msg = "tshift requires timeseries"
14721461
request.node.add_marker(pytest.mark.xfail(reason=msg))
1473-
elif dropna and transformation_func in (
1474-
"cummax",
1475-
"cummin",
1476-
"cumsum",
1477-
"fillna",
1478-
"rank",
1479-
):
1480-
msg = "produces incorrect results when nans are present"
1481-
request.node.add_marker(pytest.mark.xfail(reason=msg))
14821462
args = (0,) if transformation_func == "fillna" else ()
14831463
ser = Series([1, 2, 2], index=[1, 2, 3])
14841464
gb = ser.groupby([1, 1, np.nan], dropna=dropna)
@@ -1502,4 +1482,10 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func)
15021482
msg = f"{transformation_func} is deprecated"
15031483
with tm.assert_produces_warning(warn, match=msg):
15041484
result = gb.transform(transformation_func, *args)
1505-
tm.assert_equal(result, expected)
1485+
if dropna and transformation_func == "fillna":
1486+
# GH#46369 - result name is the group; remove this block when fixed.
1487+
tm.assert_equal(result, expected, check_names=False)
1488+
# This should be None
1489+
assert result.name == 1.0
1490+
else:
1491+
tm.assert_equal(result, expected)

0 commit comments

Comments
 (0)