Skip to content

Commit a68048e

Browse files
authored
ENH: Support skipna parameter in GroupBy min, max, prod, median, var, std and sem methods (#60752)
1 parent f1441b2 commit a68048e

File tree

10 files changed

+405
-51
lines changed

10 files changed

+405
-51
lines changed

doc/source/whatsnew/v3.0.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ Other enhancements
5959
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
6060
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
6161
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
62+
- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` methods ``sum``, ``mean``, ``median``, ``prod``, ``min``, ``max``, ``std``, ``var`` and ``sem`` now accept ``skipna`` parameter (:issue:`15675`)
6263
- :class:`Rolling` and :class:`Expanding` now support aggregations ``first`` and ``last`` (:issue:`33155`)
6364
- :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
64-
- :meth:`.DataFrameGroupBy.mean`, :meth:`.DataFrameGroupBy.sum`, :meth:`.SeriesGroupBy.mean` and :meth:`.SeriesGroupBy.sum` now accept ``skipna`` parameter (:issue:`15675`)
6565
- :meth:`.DataFrameGroupBy.transform`, :meth:`.SeriesGroupBy.transform`, :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, :meth:`.SeriesGroupBy.apply`, :meth:`.DataFrameGroupBy.apply` now support ``kurt`` (:issue:`40139`)
6666
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
6767
- :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`)

pandas/_libs/groupby.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def group_median_float64(
1313
mask: np.ndarray | None = ...,
1414
result_mask: np.ndarray | None = ...,
1515
is_datetimelike: bool = ..., # bint
16+
skipna: bool = ...,
1617
) -> None: ...
1718
def group_cumprod(
1819
out: np.ndarray, # float64_t[:, ::1]
@@ -76,6 +77,7 @@ def group_prod(
7677
mask: np.ndarray | None,
7778
result_mask: np.ndarray | None = ...,
7879
min_count: int = ...,
80+
skipna: bool = ...,
7981
) -> None: ...
8082
def group_var(
8183
out: np.ndarray, # floating[:, ::1]
@@ -88,6 +90,7 @@ def group_var(
8890
result_mask: np.ndarray | None = ...,
8991
is_datetimelike: bool = ...,
9092
name: str = ...,
93+
skipna: bool = ...,
9194
) -> None: ...
9295
def group_skew(
9396
out: np.ndarray, # float64_t[:, ::1]
@@ -183,6 +186,7 @@ def group_max(
183186
is_datetimelike: bool = ...,
184187
mask: np.ndarray | None = ...,
185188
result_mask: np.ndarray | None = ...,
189+
skipna: bool = ...,
186190
) -> None: ...
187191
def group_min(
188192
out: np.ndarray, # groupby_t[:, ::1]
@@ -193,6 +197,7 @@ def group_min(
193197
is_datetimelike: bool = ...,
194198
mask: np.ndarray | None = ...,
195199
result_mask: np.ndarray | None = ...,
200+
skipna: bool = ...,
196201
) -> None: ...
197202
def group_idxmin_idxmax(
198203
out: npt.NDArray[np.intp],

pandas/_libs/groupby.pyx

+77-22
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ cdef enum InterpolationEnumType:
6262
INTERPOLATION_MIDPOINT
6363

6464

65-
cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept nogil:
65+
cdef float64_t median_linear_mask(
66+
float64_t* a,
67+
int n,
68+
uint8_t* mask,
69+
bint skipna=True
70+
) noexcept nogil:
6671
cdef:
6772
int i, j, na_count = 0
6873
float64_t* tmp
@@ -77,7 +82,7 @@ cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept n
7782
na_count += 1
7883

7984
if na_count:
80-
if na_count == n:
85+
if na_count == n or not skipna:
8186
return NaN
8287

8388
tmp = <float64_t*>malloc((n - na_count) * sizeof(float64_t))
@@ -104,7 +109,8 @@ cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept n
104109
cdef float64_t median_linear(
105110
float64_t* a,
106111
int n,
107-
bint is_datetimelike=False
112+
bint is_datetimelike=False,
113+
bint skipna=True,
108114
) noexcept nogil:
109115
cdef:
110116
int i, j, na_count = 0
@@ -125,7 +131,7 @@ cdef float64_t median_linear(
125131
na_count += 1
126132

127133
if na_count:
128-
if na_count == n:
134+
if na_count == n or not skipna:
129135
return NaN
130136

131137
tmp = <float64_t*>malloc((n - na_count) * sizeof(float64_t))
@@ -186,6 +192,7 @@ def group_median_float64(
186192
const uint8_t[:, :] mask=None,
187193
uint8_t[:, ::1] result_mask=None,
188194
bint is_datetimelike=False,
195+
bint skipna=True,
189196
) -> None:
190197
"""
191198
Only aggregates on axis=0
@@ -229,7 +236,7 @@ def group_median_float64(
229236

230237
for j in range(ngroups):
231238
size = _counts[j + 1]
232-
result = median_linear_mask(ptr, size, ptr_mask)
239+
result = median_linear_mask(ptr, size, ptr_mask, skipna)
233240
out[j, i] = result
234241

235242
if result != result:
@@ -244,7 +251,7 @@ def group_median_float64(
244251
ptr += _counts[0]
245252
for j in range(ngroups):
246253
size = _counts[j + 1]
247-
out[j, i] = median_linear(ptr, size, is_datetimelike)
254+
out[j, i] = median_linear(ptr, size, is_datetimelike, skipna)
248255
ptr += size
249256

250257

@@ -804,17 +811,18 @@ def group_prod(
804811
const uint8_t[:, ::1] mask,
805812
uint8_t[:, ::1] result_mask=None,
806813
Py_ssize_t min_count=0,
814+
bint skipna=True,
807815
) -> None:
808816
"""
809817
Only aggregates on axis=0
810818
"""
811819
cdef:
812820
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
813-
int64float_t val
821+
int64float_t val, nan_val
814822
int64float_t[:, ::1] prodx
815823
int64_t[:, ::1] nobs
816824
Py_ssize_t len_values = len(values), len_labels = len(labels)
817-
bint isna_entry, uses_mask = mask is not None
825+
bint isna_entry, isna_result, uses_mask = mask is not None
818826

819827
if len_values != len_labels:
820828
raise ValueError("len(index) != len(labels)")
@@ -823,6 +831,7 @@ def group_prod(
823831
prodx = np.ones((<object>out).shape, dtype=(<object>out).base.dtype)
824832

825833
N, K = (<object>values).shape
834+
nan_val = _get_na_val(<int64float_t>0, False)
826835

827836
with nogil:
828837
for i in range(N):
@@ -836,12 +845,23 @@ def group_prod(
836845

837846
if uses_mask:
838847
isna_entry = mask[i, j]
848+
isna_result = result_mask[lab, j]
839849
else:
840850
isna_entry = _treat_as_na(val, False)
851+
isna_result = _treat_as_na(prodx[lab, j], False)
852+
853+
if not skipna and isna_result:
854+
# If prod is already NA, no need to update it
855+
continue
841856

842857
if not isna_entry:
843858
nobs[lab, j] += 1
844859
prodx[lab, j] *= val
860+
elif not skipna:
861+
if uses_mask:
862+
result_mask[lab, j] = True
863+
else:
864+
prodx[lab, j] = nan_val
845865

846866
_check_below_mincount(
847867
out, uses_mask, result_mask, ncounts, K, nobs, min_count, prodx
@@ -862,14 +882,15 @@ def group_var(
862882
uint8_t[:, ::1] result_mask=None,
863883
bint is_datetimelike=False,
864884
str name="var",
885+
bint skipna=True,
865886
) -> None:
866887
cdef:
867888
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
868889
floating val, ct, oldmean
869890
floating[:, ::1] mean
870891
int64_t[:, ::1] nobs
871892
Py_ssize_t len_values = len(values), len_labels = len(labels)
872-
bint isna_entry, uses_mask = mask is not None
893+
bint isna_entry, isna_result, uses_mask = mask is not None
873894
bint is_std = name == "std"
874895
bint is_sem = name == "sem"
875896

@@ -898,19 +919,34 @@ def group_var(
898919

899920
if uses_mask:
900921
isna_entry = mask[i, j]
922+
isna_result = result_mask[lab, j]
901923
elif is_datetimelike:
902924
# With group_var, we cannot just use _treat_as_na bc
903925
# datetimelike dtypes get cast to float64 instead of
904926
# to int64.
905927
isna_entry = val == NPY_NAT
928+
isna_result = out[lab, j] == NPY_NAT
906929
else:
907930
isna_entry = _treat_as_na(val, is_datetimelike)
931+
isna_result = _treat_as_na(out[lab, j], is_datetimelike)
932+
933+
if not skipna and isna_result:
934+
# If aggregate is already NA, don't add to it. This is important for
935+
# datetimelike because adding a value to NPY_NAT may not result
936+
# in a NPY_NAT
937+
continue
908938

909939
if not isna_entry:
910940
nobs[lab, j] += 1
911941
oldmean = mean[lab, j]
912942
mean[lab, j] += (val - oldmean) / nobs[lab, j]
913943
out[lab, j] += (val - mean[lab, j]) * (val - oldmean)
944+
elif not skipna:
945+
nobs[lab, j] = 0
946+
if uses_mask:
947+
result_mask[lab, j] = True
948+
else:
949+
out[lab, j] = NAN
914950

915951
for i in range(ncounts):
916952
for j in range(K):
@@ -1164,7 +1200,7 @@ def group_mean(
11641200
mean_t[:, ::1] sumx, compensation
11651201
int64_t[:, ::1] nobs
11661202
Py_ssize_t len_values = len(values), len_labels = len(labels)
1167-
bint isna_entry, uses_mask = mask is not None
1203+
bint isna_entry, isna_result, uses_mask = mask is not None
11681204

11691205
assert min_count == -1, "'min_count' only used in sum and prod"
11701206

@@ -1194,25 +1230,24 @@ def group_mean(
11941230
for j in range(K):
11951231
val = values[i, j]
11961232

1197-
if not skipna and (
1198-
(uses_mask and result_mask[lab, j]) or
1199-
(is_datetimelike and sumx[lab, j] == NPY_NAT) or
1200-
_treat_as_na(sumx[lab, j], False)
1201-
):
1202-
# If sum is already NA, don't add to it. This is important for
1203-
# datetimelike because adding a value to NPY_NAT may not result
1204-
# in NPY_NAT
1205-
continue
1206-
12071233
if uses_mask:
12081234
isna_entry = mask[i, j]
1235+
isna_result = result_mask[lab, j]
12091236
elif is_datetimelike:
12101237
# With group_mean, we cannot just use _treat_as_na bc
12111238
# datetimelike dtypes get cast to float64 instead of
12121239
# to int64.
12131240
isna_entry = val == NPY_NAT
1241+
isna_result = sumx[lab, j] == NPY_NAT
12141242
else:
12151243
isna_entry = _treat_as_na(val, is_datetimelike)
1244+
isna_result = _treat_as_na(sumx[lab, j], is_datetimelike)
1245+
1246+
if not skipna and isna_result:
1247+
# If sum is already NA, don't add to it. This is important for
1248+
# datetimelike because adding a value to NPY_NAT may not result
1249+
# in NPY_NAT
1250+
continue
12161251

12171252
if not isna_entry:
12181253
nobs[lab, j] += 1
@@ -1806,6 +1841,7 @@ cdef group_min_max(
18061841
bint compute_max=True,
18071842
const uint8_t[:, ::1] mask=None,
18081843
uint8_t[:, ::1] result_mask=None,
1844+
bint skipna=True,
18091845
):
18101846
"""
18111847
Compute minimum/maximum of columns of `values`, in row groups `labels`.
@@ -1833,6 +1869,8 @@ cdef group_min_max(
18331869
result_mask : ndarray[bool, ndim=2], optional
18341870
If not None, these specify locations in the output that are NA.
18351871
Modified in-place.
1872+
skipna : bool, default True
1873+
If True, ignore nans in `values`.
18361874
18371875
Notes
18381876
-----
@@ -1841,17 +1879,18 @@ cdef group_min_max(
18411879
"""
18421880
cdef:
18431881
Py_ssize_t i, j, N, K, lab, ngroups = len(counts)
1844-
numeric_t val
1882+
numeric_t val, nan_val
18451883
numeric_t[:, ::1] group_min_or_max
18461884
int64_t[:, ::1] nobs
18471885
bint uses_mask = mask is not None
1848-
bint isna_entry
1886+
bint isna_entry, isna_result
18491887

18501888
if not len(values) == len(labels):
18511889
raise AssertionError("len(index) != len(labels)")
18521890

18531891
min_count = max(min_count, 1)
18541892
nobs = np.zeros((<object>out).shape, dtype=np.int64)
1893+
nan_val = _get_na_val(<numeric_t>0, is_datetimelike)
18551894

18561895
group_min_or_max = np.empty_like(out)
18571896
group_min_or_max[:] = _get_min_or_max(<numeric_t>0, compute_max, is_datetimelike)
@@ -1870,8 +1909,15 @@ cdef group_min_max(
18701909

18711910
if uses_mask:
18721911
isna_entry = mask[i, j]
1912+
isna_result = result_mask[lab, j]
18731913
else:
18741914
isna_entry = _treat_as_na(val, is_datetimelike)
1915+
isna_result = _treat_as_na(group_min_or_max[lab, j],
1916+
is_datetimelike)
1917+
1918+
if not skipna and isna_result:
1919+
# If current min/max is already NA, it will always be NA
1920+
continue
18751921

18761922
if not isna_entry:
18771923
nobs[lab, j] += 1
@@ -1881,6 +1927,11 @@ cdef group_min_max(
18811927
else:
18821928
if val < group_min_or_max[lab, j]:
18831929
group_min_or_max[lab, j] = val
1930+
elif not skipna:
1931+
if uses_mask:
1932+
result_mask[lab, j] = True
1933+
else:
1934+
group_min_or_max[lab, j] = nan_val
18841935

18851936
_check_below_mincount(
18861937
out, uses_mask, result_mask, ngroups, K, nobs, min_count, group_min_or_max
@@ -2012,6 +2063,7 @@ def group_max(
20122063
bint is_datetimelike=False,
20132064
const uint8_t[:, ::1] mask=None,
20142065
uint8_t[:, ::1] result_mask=None,
2066+
bint skipna=True,
20152067
) -> None:
20162068
"""See group_min_max.__doc__"""
20172069
group_min_max(
@@ -2024,6 +2076,7 @@ def group_max(
20242076
compute_max=True,
20252077
mask=mask,
20262078
result_mask=result_mask,
2079+
skipna=skipna,
20272080
)
20282081

20292082

@@ -2038,6 +2091,7 @@ def group_min(
20382091
bint is_datetimelike=False,
20392092
const uint8_t[:, ::1] mask=None,
20402093
uint8_t[:, ::1] result_mask=None,
2094+
bint skipna=True,
20412095
) -> None:
20422096
"""See group_min_max.__doc__"""
20432097
group_min_max(
@@ -2050,6 +2104,7 @@ def group_min(
20502104
compute_max=False,
20512105
mask=mask,
20522106
result_mask=result_mask,
2107+
skipna=skipna,
20532108
)
20542109

20552110

pandas/core/_numba/kernels/min_max_.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def grouped_min_max(
8888
ngroups: int,
8989
min_periods: int,
9090
is_max: bool,
91+
skipna: bool = True,
9192
) -> tuple[np.ndarray, list[int]]:
9293
N = len(labels)
9394
nobs = np.zeros(ngroups, dtype=np.int64)
@@ -97,13 +98,16 @@ def grouped_min_max(
9798
for i in range(N):
9899
lab = labels[i]
99100
val = values[i]
100-
if lab < 0:
101+
if lab < 0 or (nobs[lab] >= 1 and np.isnan(output[lab])):
101102
continue
102103

103104
if values.dtype.kind == "i" or not np.isnan(val):
104105
nobs[lab] += 1
105106
else:
106-
# NaN value cannot be a min/max value
107+
if not skipna:
108+
# If skipna is False and we encounter a NaN,
109+
# both min and max of the group will be NaN
110+
output[lab] = np.nan
107111
continue
108112

109113
if nobs[lab] == 1:

0 commit comments

Comments (0)