Skip to content

Commit 42bf375

Browse files
authored
ENH: Support skipna parameter in GroupBy mean and sum (#60741)
* ENH: Support skipna parameter in GroupBy mean and sum * Move numba tests to test_numba.py * Fix docstring and failing future string test
1 parent 7234104 commit 42bf375

File tree

9 files changed

+255
-10
lines changed

9 files changed

+255
-10
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Other enhancements
6060
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
6161
- :class:`Rolling` and :class:`Expanding` now support aggregations ``first`` and ``last`` (:issue:`33155`)
6262
- :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
63+
- :meth:`.DataFrameGroupBy.mean`, :meth:`.DataFrameGroupBy.sum`, :meth:`.SeriesGroupBy.mean` and :meth:`.SeriesGroupBy.sum` now accept ``skipna`` parameter (:issue:`15675`)
6364
- :meth:`.DataFrameGroupBy.transform`, :meth:`.SeriesGroupBy.transform`, :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, :meth:`.SeriesGroupBy.apply`, :meth:`.DataFrameGroupBy.apply` now support ``kurt`` (:issue:`40139`)
6465
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
6566
- :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`)

pandas/_libs/groupby.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def group_sum(
6666
result_mask: np.ndarray | None = ...,
6767
min_count: int = ...,
6868
is_datetimelike: bool = ...,
69+
skipna: bool = ...,
6970
) -> None: ...
7071
def group_prod(
7172
out: np.ndarray, # int64float_t[:, ::1]
@@ -115,6 +116,7 @@ def group_mean(
115116
is_datetimelike: bool = ..., # bint
116117
mask: np.ndarray | None = ...,
117118
result_mask: np.ndarray | None = ...,
119+
skipna: bool = ...,
118120
) -> None: ...
119121
def group_ohlc(
120122
out: np.ndarray, # floatingintuint_t[:, ::1]

pandas/_libs/groupby.pyx

+47-1
Original file line numberDiff line numberDiff line change
@@ -700,13 +700,14 @@ def group_sum(
700700
uint8_t[:, ::1] result_mask=None,
701701
Py_ssize_t min_count=0,
702702
bint is_datetimelike=False,
703+
bint skipna=True,
703704
) -> None:
704705
"""
705706
Only aggregates on axis=0 using Kahan summation
706707
"""
707708
cdef:
708709
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
709-
sum_t val, t, y
710+
sum_t val, t, y, nan_val
710711
sum_t[:, ::1] sumx, compensation
711712
int64_t[:, ::1] nobs
712713
Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -722,6 +723,15 @@ def group_sum(
722723
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
723724

724725
N, K = (<object>values).shape
726+
if uses_mask:
727+
nan_val = 0
728+
elif is_datetimelike:
729+
nan_val = NPY_NAT
730+
elif sum_t is int64_t or sum_t is uint64_t:
731+
# This has no effect as int64 can't be nan. Setting to 0 to avoid type error
732+
nan_val = 0
733+
else:
734+
nan_val = NAN
725735

726736
with nogil(sum_t is not object):
727737
for i in range(N):
@@ -734,6 +744,16 @@ def group_sum(
734744
for j in range(K):
735745
val = values[i, j]
736746

747+
if not skipna and (
748+
(uses_mask and result_mask[lab, j]) or
749+
(is_datetimelike and sumx[lab, j] == NPY_NAT) or
750+
_treat_as_na(sumx[lab, j], False)
751+
):
752+
# If sum is already NA, don't add to it. This is important for
753+
# datetimelikebecause adding a value to NPY_NAT may not result
754+
# in a NPY_NAT
755+
continue
756+
737757
if uses_mask:
738758
isna_entry = mask[i, j]
739759
else:
@@ -765,6 +785,11 @@ def group_sum(
765785
# because of no gil
766786
compensation[lab, j] = 0
767787
sumx[lab, j] = t
788+
elif not skipna:
789+
if uses_mask:
790+
result_mask[lab, j] = True
791+
else:
792+
sumx[lab, j] = nan_val
768793

769794
_check_below_mincount(
770795
out, uses_mask, result_mask, ncounts, K, nobs, min_count, sumx
@@ -1100,6 +1125,7 @@ def group_mean(
11001125
bint is_datetimelike=False,
11011126
const uint8_t[:, ::1] mask=None,
11021127
uint8_t[:, ::1] result_mask=None,
1128+
bint skipna=True,
11031129
) -> None:
11041130
"""
11051131
Compute the mean per label given a label assignment for each value.
@@ -1125,6 +1151,8 @@ def group_mean(
11251151
Mask of the input values.
11261152
result_mask : ndarray[bool, ndim=2], optional
11271153
Mask of the out array
1154+
skipna : bool, optional
1155+
If True, ignore nans in `values`.
11281156

11291157
Notes
11301158
-----
@@ -1168,6 +1196,16 @@ def group_mean(
11681196
for j in range(K):
11691197
val = values[i, j]
11701198

1199+
if not skipna and (
1200+
(uses_mask and result_mask[lab, j]) or
1201+
(is_datetimelike and sumx[lab, j] == NPY_NAT) or
1202+
_treat_as_na(sumx[lab, j], False)
1203+
):
1204+
# If sum is already NA, don't add to it. This is important for
1205+
# datetimelike because adding a value to NPY_NAT may not result
1206+
# in NPY_NAT
1207+
continue
1208+
11711209
if uses_mask:
11721210
isna_entry = mask[i, j]
11731211
elif is_datetimelike:
@@ -1191,6 +1229,14 @@ def group_mean(
11911229
# because of no gil
11921230
compensation[lab, j] = 0.
11931231
sumx[lab, j] = t
1232+
elif not skipna:
1233+
# Set the nobs to 0 so that in case of datetimelike,
1234+
# dividing NPY_NAT by nobs may not result in a NPY_NAT
1235+
nobs[lab, j] = 0
1236+
if uses_mask:
1237+
result_mask[lab, j] = True
1238+
else:
1239+
sumx[lab, j] = nan_val
11941240

11951241
for i in range(ncounts):
11961242
for j in range(K):

pandas/core/_numba/kernels/mean_.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,10 @@ def grouped_mean(
169169
labels: npt.NDArray[np.intp],
170170
ngroups: int,
171171
min_periods: int,
172+
skipna: bool,
172173
) -> tuple[np.ndarray, list[int]]:
173174
output, nobs_arr, comp_arr, consecutive_counts, prev_vals = grouped_kahan_sum(
174-
values, result_dtype, labels, ngroups
175+
values, result_dtype, labels, ngroups, skipna
175176
)
176177

177178
# Post-processing, replace sums that don't satisfy min_periods

pandas/core/_numba/kernels/sum_.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def grouped_kahan_sum(
165165
result_dtype: np.dtype,
166166
labels: npt.NDArray[np.intp],
167167
ngroups: int,
168+
skipna: bool,
168169
) -> tuple[
169170
np.ndarray, npt.NDArray[np.int64], np.ndarray, npt.NDArray[np.int64], np.ndarray
170171
]:
@@ -180,7 +181,15 @@ def grouped_kahan_sum(
180181
lab = labels[i]
181182
val = values[i]
182183

183-
if lab < 0:
184+
if lab < 0 or np.isnan(output[lab]):
185+
continue
186+
187+
if not skipna and np.isnan(val):
188+
output[lab] = np.nan
189+
nobs_arr[lab] += 1
190+
comp_arr[lab] = np.nan
191+
consecutive_counts[lab] = 1
192+
prev_vals[lab] = np.nan
184193
continue
185194

186195
sum_x = output[lab]
@@ -219,11 +228,12 @@ def grouped_sum(
219228
labels: npt.NDArray[np.intp],
220229
ngroups: int,
221230
min_periods: int,
231+
skipna: bool,
222232
) -> tuple[np.ndarray, list[int]]:
223233
na_pos = []
224234

225235
output, nobs_arr, comp_arr, consecutive_counts, prev_vals = grouped_kahan_sum(
226-
values, result_dtype, labels, ngroups
236+
values, result_dtype, labels, ngroups, skipna
227237
)
228238

229239
# Post-processing, replace sums that don't satisfy min_periods

pandas/core/groupby/groupby.py

+72-2
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,61 @@ class providing the base-class of operations.
214214
{example}
215215
"""
216216

217+
_groupby_agg_method_skipna_engine_template = """
218+
Compute {fname} of group values.
219+
220+
Parameters
221+
----------
222+
numeric_only : bool, default {no}
223+
Include only float, int, boolean columns.
224+
225+
.. versionchanged:: 2.0.0
226+
227+
numeric_only no longer accepts ``None``.
228+
229+
min_count : int, default {mc}
230+
The required number of valid values to perform the operation. If fewer
231+
than ``min_count`` non-NA values are present the result will be NA.
232+
233+
skipna : bool, default {s}
234+
Exclude NA/null values. If the entire group is NA and ``skipna`` is
235+
``True``, the result will be NA.
236+
237+
.. versionchanged:: 3.0.0
238+
239+
engine : str, default None {e}
240+
* ``'cython'`` : Runs rolling apply through C-extensions from cython.
241+
* ``'numba'`` : Runs rolling apply through JIT compiled code from numba.
242+
Only available when ``raw`` is set to ``True``.
243+
* ``None`` : Defaults to ``'cython'`` or globally setting ``compute.use_numba``
244+
245+
engine_kwargs : dict, default None {ek}
246+
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
247+
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
248+
and ``parallel`` dictionary keys. The values must either be ``True`` or
249+
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
250+
``{{'nopython': True, 'nogil': False, 'parallel': False}}`` and will be
251+
applied to both the ``func`` and the ``apply`` groupby aggregation.
252+
253+
Returns
254+
-------
255+
Series or DataFrame
256+
Computed {fname} of values within each group.
257+
258+
See Also
259+
--------
260+
SeriesGroupBy.min : Return the min of the group values.
261+
DataFrameGroupBy.min : Return the min of the group values.
262+
SeriesGroupBy.max : Return the max of the group values.
263+
DataFrameGroupBy.max : Return the max of the group values.
264+
SeriesGroupBy.sum : Return the sum of the group values.
265+
DataFrameGroupBy.sum : Return the sum of the group values.
266+
267+
Examples
268+
--------
269+
{example}
270+
"""
271+
217272
_pipe_template = """
218273
Apply a ``func`` with arguments to this %(klass)s object and return its result.
219274
@@ -2091,6 +2146,7 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
20912146
def mean(
20922147
self,
20932148
numeric_only: bool = False,
2149+
skipna: bool = True,
20942150
engine: Literal["cython", "numba"] | None = None,
20952151
engine_kwargs: dict[str, bool] | None = None,
20962152
):
@@ -2106,6 +2162,12 @@ def mean(
21062162
21072163
numeric_only no longer accepts ``None`` and defaults to ``False``.
21082164
2165+
skipna : bool, default True
2166+
Exclude NA/null values. If an entire row/column is NA, the result
2167+
will be NA.
2168+
2169+
.. versionadded:: 3.0.0
2170+
21092171
engine : str, default None
21102172
* ``'cython'`` : Runs the operation through C-extensions from cython.
21112173
* ``'numba'`` : Runs the operation through JIT compiled code from numba.
@@ -2172,12 +2234,16 @@ def mean(
21722234
executor.float_dtype_mapping,
21732235
engine_kwargs,
21742236
min_periods=0,
2237+
skipna=skipna,
21752238
)
21762239
else:
21772240
result = self._cython_agg_general(
21782241
"mean",
2179-
alt=lambda x: Series(x, copy=False).mean(numeric_only=numeric_only),
2242+
alt=lambda x: Series(x, copy=False).mean(
2243+
numeric_only=numeric_only, skipna=skipna
2244+
),
21802245
numeric_only=numeric_only,
2246+
skipna=skipna,
21812247
)
21822248
return result.__finalize__(self.obj, method="groupby")
21832249

@@ -2817,10 +2883,11 @@ def size(self) -> DataFrame | Series:
28172883

28182884
@final
28192885
@doc(
2820-
_groupby_agg_method_engine_template,
2886+
_groupby_agg_method_skipna_engine_template,
28212887
fname="sum",
28222888
no=False,
28232889
mc=0,
2890+
s=True,
28242891
e=None,
28252892
ek=None,
28262893
example=dedent(
@@ -2862,6 +2929,7 @@ def sum(
28622929
self,
28632930
numeric_only: bool = False,
28642931
min_count: int = 0,
2932+
skipna: bool = True,
28652933
engine: Literal["cython", "numba"] | None = None,
28662934
engine_kwargs: dict[str, bool] | None = None,
28672935
):
@@ -2873,6 +2941,7 @@ def sum(
28732941
executor.default_dtype_mapping,
28742942
engine_kwargs,
28752943
min_periods=min_count,
2944+
skipna=skipna,
28762945
)
28772946
else:
28782947
# If we are grouping on categoricals we want unobserved categories to
@@ -2884,6 +2953,7 @@ def sum(
28842953
min_count=min_count,
28852954
alias="sum",
28862955
npfunc=np.sum,
2956+
skipna=skipna,
28872957
)
28882958

28892959
return result

pandas/tests/groupby/aggregate/test_numba.py

+17
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,23 @@ def test_multifunc_numba_vs_cython_frame(agg_kwargs):
186186
tm.assert_frame_equal(result, expected)
187187

188188

189+
@pytest.mark.parametrize("func", ["sum", "mean"])
190+
def test_multifunc_numba_vs_cython_frame_noskipna(func):
191+
pytest.importorskip("numba")
192+
data = DataFrame(
193+
{
194+
0: ["a", "a", "b", "b", "a"],
195+
1: [1.0, np.nan, 3.0, 4.0, 5.0],
196+
2: [1, 2, 3, 4, 5],
197+
},
198+
columns=[0, 1, 2],
199+
)
200+
grouped = data.groupby(0)
201+
result = grouped.agg(func, skipna=False, engine="numba")
202+
expected = grouped.agg(func, skipna=False, engine="cython")
203+
tm.assert_frame_equal(result, expected)
204+
205+
189206
@pytest.mark.parametrize(
190207
"agg_kwargs,expected_func",
191208
[

pandas/tests/groupby/test_api.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,10 @@ def test_frame_consistency(groupby_func):
176176
elif groupby_func in ("max", "min"):
177177
exclude_expected = {"axis", "kwargs", "skipna"}
178178
exclude_result = {"min_count", "engine", "engine_kwargs"}
179-
elif groupby_func in ("mean", "std", "sum", "var"):
179+
elif groupby_func in ("sum", "mean"):
180+
exclude_expected = {"axis", "kwargs"}
181+
exclude_result = {"engine", "engine_kwargs"}
182+
elif groupby_func in ("std", "var"):
180183
exclude_expected = {"axis", "kwargs", "skipna"}
181184
exclude_result = {"engine", "engine_kwargs"}
182185
elif groupby_func in ("median", "prod", "sem"):
@@ -234,7 +237,10 @@ def test_series_consistency(request, groupby_func):
234237
elif groupby_func in ("max", "min"):
235238
exclude_expected = {"axis", "kwargs", "skipna"}
236239
exclude_result = {"min_count", "engine", "engine_kwargs"}
237-
elif groupby_func in ("mean", "std", "sum", "var"):
240+
elif groupby_func in ("sum", "mean"):
241+
exclude_expected = {"axis", "kwargs"}
242+
exclude_result = {"engine", "engine_kwargs"}
243+
elif groupby_func in ("std", "var"):
238244
exclude_expected = {"axis", "kwargs", "skipna"}
239245
exclude_result = {"engine", "engine_kwargs"}
240246
elif groupby_func in ("median", "prod", "sem"):

0 commit comments

Comments
 (0)