Skip to content

Commit 6e627c1

Browse files
committed
ENH: Added a min_count keyword to stat funcs (#18876)
The current default is 1, reproducing the behavior of pandas 0.21. The current test suite should pass. Currently, only nansum and nanprod actually do anything with `min_count`. It will not be hard to adjust other nan* methods use it if we want. This was just simplest for now. Additional tests for the new behavior have been added. (cherry picked from commit dbec3c9)
1 parent 7bb204a commit 6e627c1

10 files changed

+445
-83
lines changed

pandas/_libs/groupby_helper.pxi.in

+39-12
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def get_dispatch(dtypes):
3636
def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
3737
ndarray[int64_t] counts,
3838
ndarray[{{c_type}}, ndim=2] values,
39-
ndarray[int64_t] labels):
39+
ndarray[int64_t] labels,
40+
Py_ssize_t min_count=1):
4041
"""
4142
Only aggregates on axis=0
4243
"""
@@ -88,7 +89,7 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
8889

8990
for i in range(ncounts):
9091
for j in range(K):
91-
if nobs[i, j] == 0:
92+
if nobs[i, j] < min_count:
9293
out[i, j] = NAN
9394
else:
9495
out[i, j] = sumx[i, j]
@@ -99,7 +100,8 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
99100
def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
100101
ndarray[int64_t] counts,
101102
ndarray[{{c_type}}, ndim=2] values,
102-
ndarray[int64_t] labels):
103+
ndarray[int64_t] labels,
104+
Py_ssize_t min_count=1):
103105
"""
104106
Only aggregates on axis=0
105107
"""
@@ -147,7 +149,7 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
147149

148150
for i in range(ncounts):
149151
for j in range(K):
150-
if nobs[i, j] == 0:
152+
if nobs[i, j] < min_count:
151153
out[i, j] = NAN
152154
else:
153155
out[i, j] = prodx[i, j]
@@ -159,12 +161,15 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
159161
def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
160162
ndarray[int64_t] counts,
161163
ndarray[{{dest_type2}}, ndim=2] values,
162-
ndarray[int64_t] labels):
164+
ndarray[int64_t] labels,
165+
Py_ssize_t min_count=-1):
163166
cdef:
164167
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
165168
{{dest_type2}} val, ct, oldmean
166169
ndarray[{{dest_type2}}, ndim=2] nobs, mean
167170

171+
assert min_count == -1, "'min_count' only used in add and prod"
172+
168173
if not len(values) == len(labels):
169174
raise AssertionError("len(index) != len(labels)")
170175

@@ -208,12 +213,15 @@ def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
208213
def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
209214
ndarray[int64_t] counts,
210215
ndarray[{{dest_type2}}, ndim=2] values,
211-
ndarray[int64_t] labels):
216+
ndarray[int64_t] labels,
217+
Py_ssize_t min_count=-1):
212218
cdef:
213219
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
214220
{{dest_type2}} val, count
215221
ndarray[{{dest_type2}}, ndim=2] sumx, nobs
216222

223+
assert min_count == -1, "'min_count' only used in add and prod"
224+
217225
if not len(values) == len(labels):
218226
raise AssertionError("len(index) != len(labels)")
219227

@@ -263,7 +271,8 @@ def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
263271
def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
264272
ndarray[int64_t] counts,
265273
ndarray[{{dest_type2}}, ndim=2] values,
266-
ndarray[int64_t] labels):
274+
ndarray[int64_t] labels,
275+
Py_ssize_t min_count=-1):
267276
"""
268277
Only aggregates on axis=0
269278
"""
@@ -272,6 +281,8 @@ def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
272281
{{dest_type2}} val, count
273282
Py_ssize_t ngroups = len(counts)
274283

284+
assert min_count == -1, "'min_count' only used in add and prod"
285+
275286
if len(labels) == 0:
276287
return
277288

@@ -332,7 +343,8 @@ def get_dispatch(dtypes):
332343
def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
333344
ndarray[int64_t] counts,
334345
ndarray[{{c_type}}, ndim=2] values,
335-
ndarray[int64_t] labels):
346+
ndarray[int64_t] labels,
347+
Py_ssize_t min_count=-1):
336348
"""
337349
Only aggregates on axis=0
338350
"""
@@ -342,6 +354,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
342354
ndarray[{{dest_type2}}, ndim=2] resx
343355
ndarray[int64_t, ndim=2] nobs
344356

357+
assert min_count == -1, "'min_count' only used in add and prod"
358+
345359
if not len(values) == len(labels):
346360
raise AssertionError("len(index) != len(labels)")
347361

@@ -382,7 +396,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
382396
def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
383397
ndarray[int64_t] counts,
384398
ndarray[{{c_type}}, ndim=2] values,
385-
ndarray[int64_t] labels, int64_t rank):
399+
ndarray[int64_t] labels, int64_t rank,
400+
Py_ssize_t min_count=-1):
386401
"""
387402
Only aggregates on axis=0
388403
"""
@@ -392,6 +407,8 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
392407
ndarray[{{dest_type2}}, ndim=2] resx
393408
ndarray[int64_t, ndim=2] nobs
394409

410+
assert min_count == -1, "'min_count' only used in add and prod"
411+
395412
if not len(values) == len(labels):
396413
raise AssertionError("len(index) != len(labels)")
397414

@@ -455,7 +472,8 @@ def get_dispatch(dtypes):
455472
def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
456473
ndarray[int64_t] counts,
457474
ndarray[{{dest_type2}}, ndim=2] values,
458-
ndarray[int64_t] labels):
475+
ndarray[int64_t] labels,
476+
Py_ssize_t min_count=-1):
459477
"""
460478
Only aggregates on axis=0
461479
"""
@@ -464,6 +482,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
464482
{{dest_type2}} val, count
465483
ndarray[{{dest_type2}}, ndim=2] maxx, nobs
466484

485+
assert min_count == -1, "'min_count' only used in add and prod"
486+
467487
if not len(values) == len(labels):
468488
raise AssertionError("len(index) != len(labels)")
469489

@@ -526,7 +546,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
526546
def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
527547
ndarray[int64_t] counts,
528548
ndarray[{{dest_type2}}, ndim=2] values,
529-
ndarray[int64_t] labels):
549+
ndarray[int64_t] labels,
550+
Py_ssize_t min_count=-1):
530551
"""
531552
Only aggregates on axis=0
532553
"""
@@ -535,6 +556,8 @@ def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
535556
{{dest_type2}} val, count
536557
ndarray[{{dest_type2}}, ndim=2] minx, nobs
537558

559+
assert min_count == -1, "'min_count' only used in add and prod"
560+
538561
if not len(values) == len(labels):
539562
raise AssertionError("len(index) != len(labels)")
540563

@@ -686,7 +709,8 @@ def group_cummax_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
686709
def group_median_float64(ndarray[float64_t, ndim=2] out,
687710
ndarray[int64_t] counts,
688711
ndarray[float64_t, ndim=2] values,
689-
ndarray[int64_t] labels):
712+
ndarray[int64_t] labels,
713+
Py_ssize_t min_count=-1):
690714
"""
691715
Only aggregates on axis=0
692716
"""
@@ -695,6 +719,9 @@ def group_median_float64(ndarray[float64_t, ndim=2] out,
695719
ndarray[int64_t] _counts
696720
ndarray data
697721
float64_t* ptr
722+
723+
assert min_count == -1, "'min_count' only used in add and prod"
724+
698725
ngroups = len(counts)
699726
N, K = (<object> values).shape
700727

pandas/core/generic.py

+96-8
Original file line numberDiff line numberDiff line change
@@ -6921,7 +6921,8 @@ def _add_numeric_operations(cls):
69216921
@Substitution(outname='mad',
69226922
desc="Return the mean absolute deviation of the values "
69236923
"for the requested axis",
6924-
name1=name, name2=name2, axis_descr=axis_descr)
6924+
name1=name, name2=name2, axis_descr=axis_descr,
6925+
min_count='', examples='')
69256926
@Appender(_num_doc)
69266927
def mad(self, axis=None, skipna=None, level=None):
69276928
if skipna is None:
@@ -6962,7 +6963,8 @@ def mad(self, axis=None, skipna=None, level=None):
69626963
@Substitution(outname='compounded',
69636964
desc="Return the compound percentage of the values for "
69646965
"the requested axis", name1=name, name2=name2,
6965-
axis_descr=axis_descr)
6966+
axis_descr=axis_descr,
6967+
min_count='', examples='')
69666968
@Appender(_num_doc)
69676969
def compound(self, axis=None, skipna=None, level=None):
69686970
if skipna is None:
@@ -6986,10 +6988,10 @@ def compound(self, axis=None, skipna=None, level=None):
69866988
lambda y, axis: np.maximum.accumulate(y, axis), "max",
69876989
-np.inf, np.nan)
69886990

6989-
cls.sum = _make_stat_function(
6991+
cls.sum = _make_min_count_stat_function(
69906992
cls, 'sum', name, name2, axis_descr,
69916993
'Return the sum of the values for the requested axis',
6992-
nanops.nansum)
6994+
nanops.nansum, _sum_examples)
69936995
cls.mean = _make_stat_function(
69946996
cls, 'mean', name, name2, axis_descr,
69956997
'Return the mean of the values for the requested axis',
@@ -7005,10 +7007,10 @@ def compound(self, axis=None, skipna=None, level=None):
70057007
"by N-1\n",
70067008
nanops.nankurt)
70077009
cls.kurtosis = cls.kurt
7008-
cls.prod = _make_stat_function(
7010+
cls.prod = _make_min_count_stat_function(
70097011
cls, 'prod', name, name2, axis_descr,
70107012
'Return the product of the values for the requested axis',
7011-
nanops.nanprod)
7013+
nanops.nanprod, _prod_examples)
70127014
cls.product = cls.prod
70137015
cls.median = _make_stat_function(
70147016
cls, 'median', name, name2, axis_descr,
@@ -7139,10 +7141,13 @@ def _doc_parms(cls):
71397141
numeric_only : boolean, default None
71407142
Include only float, int, boolean columns. If None, will attempt to use
71417143
everything, then use only numeric data. Not implemented for Series.
7144+
%(min_count)s\
71427145
71437146
Returns
71447147
-------
7145-
%(outname)s : %(name1)s or %(name2)s (if level specified)\n"""
7148+
%(outname)s : %(name1)s or %(name2)s (if level specified)
7149+
7150+
%(examples)s"""
71467151

71477152
_num_ddof_doc = """
71487153
@@ -7210,9 +7215,92 @@ def _doc_parms(cls):
72107215
"""
72117216

72127217

7218+
_sum_examples = """\
7219+
Examples
7220+
--------
7221+
By default, the sum of an empty series is ``NaN``.
7222+
7223+
>>> pd.Series([]).sum() # min_count=1 is the default
7224+
nan
7225+
7226+
This can be controlled with the ``min_count`` parameter. For example, if
7227+
you'd like the sum of an empty series to be 0, pass ``min_count=0``.
7228+
7229+
>>> pd.Series([]).sum(min_count=0)
7230+
0.0
7231+
7232+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7233+
empty series identically.
7234+
7235+
>>> pd.Series([np.nan]).sum()
7236+
nan
7237+
7238+
>>> pd.Series([np.nan]).sum(min_count=0)
7239+
0.0
7240+
"""
7241+
7242+
_prod_examples = """\
7243+
Examples
7244+
--------
7245+
By default, the product of an empty series is ``NaN``
7246+
7247+
>>> pd.Series([]).prod()
7248+
nan
7249+
7250+
This can be controlled with the ``min_count`` parameter
7251+
7252+
>>> pd.Series([]).prod(min_count=0)
7253+
1.0
7254+
7255+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7256+
empty series identically.
7257+
7258+
>>> pd.Series([np.nan]).prod()
7259+
nan
7260+
7261+
>>> pd.Series([np.nan]).sum(min_count=0)
7262+
1.0
7263+
"""
7264+
7265+
7266+
_min_count_stub = """\
7267+
min_count : int, default 1
7268+
The required number of valid values to perform the operation. If fewer than
7269+
``min_count`` non-NA values are present the result will be NA.
7270+
7271+
.. versionadded :: 0.21.2
7272+
7273+
Added with the default being 1. This means the sum or product
7274+
of an all-NA or empty series is ``NaN``.
7275+
"""
7276+
7277+
7278+
def _make_min_count_stat_function(cls, name, name1, name2, axis_descr, desc,
7279+
f, examples):
7280+
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7281+
axis_descr=axis_descr, min_count=_min_count_stub,
7282+
examples=examples)
7283+
@Appender(_num_doc)
7284+
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
7285+
min_count=1,
7286+
**kwargs):
7287+
nv.validate_stat_func(tuple(), kwargs, fname=name)
7288+
if skipna is None:
7289+
skipna = True
7290+
if axis is None:
7291+
axis = self._stat_axis_number
7292+
if level is not None:
7293+
return self._agg_by_level(name, axis=axis, level=level,
7294+
skipna=skipna, min_count=min_count)
7295+
return self._reduce(f, name, axis=axis, skipna=skipna,
7296+
numeric_only=numeric_only, min_count=min_count)
7297+
7298+
return set_function_name(stat_func, name, cls)
7299+
7300+
72137301
def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f):
72147302
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7215-
axis_descr=axis_descr)
7303+
axis_descr=axis_descr, min_count='', examples='')
72167304
@Appender(_num_doc)
72177305
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
72187306
**kwargs):

0 commit comments

Comments
 (0)