Skip to content

Commit dbec3c9

Browse files
ENH: Added a min_count keyword to stat funcs (#18876)
The current default is 1, reproducing the behavior of pandas 0.21. The current test suite should pass. Currently, only nansum and nanprod actually do anything with `min_count`. It will not be hard to adjust other nan* methods use it if we want. This was just simplest for now. Additional tests for the new behavior have been added.
1 parent a9f82df commit dbec3c9

10 files changed

+445
-83
lines changed

pandas/_libs/groupby_helper.pxi.in

+39-12
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def get_dispatch(dtypes):
3636
def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
3737
ndarray[int64_t] counts,
3838
ndarray[{{c_type}}, ndim=2] values,
39-
ndarray[int64_t] labels):
39+
ndarray[int64_t] labels,
40+
Py_ssize_t min_count=1):
4041
"""
4142
Only aggregates on axis=0
4243
"""
@@ -88,7 +89,7 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
8889

8990
for i in range(ncounts):
9091
for j in range(K):
91-
if nobs[i, j] == 0:
92+
if nobs[i, j] < min_count:
9293
out[i, j] = NAN
9394
else:
9495
out[i, j] = sumx[i, j]
@@ -99,7 +100,8 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
99100
def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
100101
ndarray[int64_t] counts,
101102
ndarray[{{c_type}}, ndim=2] values,
102-
ndarray[int64_t] labels):
103+
ndarray[int64_t] labels,
104+
Py_ssize_t min_count=1):
103105
"""
104106
Only aggregates on axis=0
105107
"""
@@ -147,7 +149,7 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
147149

148150
for i in range(ncounts):
149151
for j in range(K):
150-
if nobs[i, j] == 0:
152+
if nobs[i, j] < min_count:
151153
out[i, j] = NAN
152154
else:
153155
out[i, j] = prodx[i, j]
@@ -159,12 +161,15 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
159161
def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
160162
ndarray[int64_t] counts,
161163
ndarray[{{dest_type2}}, ndim=2] values,
162-
ndarray[int64_t] labels):
164+
ndarray[int64_t] labels,
165+
Py_ssize_t min_count=-1):
163166
cdef:
164167
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
165168
{{dest_type2}} val, ct, oldmean
166169
ndarray[{{dest_type2}}, ndim=2] nobs, mean
167170

171+
assert min_count == -1, "'min_count' only used in add and prod"
172+
168173
if not len(values) == len(labels):
169174
raise AssertionError("len(index) != len(labels)")
170175

@@ -208,12 +213,15 @@ def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
208213
def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
209214
ndarray[int64_t] counts,
210215
ndarray[{{dest_type2}}, ndim=2] values,
211-
ndarray[int64_t] labels):
216+
ndarray[int64_t] labels,
217+
Py_ssize_t min_count=-1):
212218
cdef:
213219
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
214220
{{dest_type2}} val, count
215221
ndarray[{{dest_type2}}, ndim=2] sumx, nobs
216222

223+
assert min_count == -1, "'min_count' only used in add and prod"
224+
217225
if not len(values) == len(labels):
218226
raise AssertionError("len(index) != len(labels)")
219227

@@ -263,7 +271,8 @@ def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
263271
def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
264272
ndarray[int64_t] counts,
265273
ndarray[{{dest_type2}}, ndim=2] values,
266-
ndarray[int64_t] labels):
274+
ndarray[int64_t] labels,
275+
Py_ssize_t min_count=-1):
267276
"""
268277
Only aggregates on axis=0
269278
"""
@@ -272,6 +281,8 @@ def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
272281
{{dest_type2}} val, count
273282
Py_ssize_t ngroups = len(counts)
274283

284+
assert min_count == -1, "'min_count' only used in add and prod"
285+
275286
if len(labels) == 0:
276287
return
277288

@@ -332,7 +343,8 @@ def get_dispatch(dtypes):
332343
def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
333344
ndarray[int64_t] counts,
334345
ndarray[{{c_type}}, ndim=2] values,
335-
ndarray[int64_t] labels):
346+
ndarray[int64_t] labels,
347+
Py_ssize_t min_count=-1):
336348
"""
337349
Only aggregates on axis=0
338350
"""
@@ -342,6 +354,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
342354
ndarray[{{dest_type2}}, ndim=2] resx
343355
ndarray[int64_t, ndim=2] nobs
344356

357+
assert min_count == -1, "'min_count' only used in add and prod"
358+
345359
if not len(values) == len(labels):
346360
raise AssertionError("len(index) != len(labels)")
347361

@@ -382,7 +396,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
382396
def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
383397
ndarray[int64_t] counts,
384398
ndarray[{{c_type}}, ndim=2] values,
385-
ndarray[int64_t] labels, int64_t rank):
399+
ndarray[int64_t] labels, int64_t rank,
400+
Py_ssize_t min_count=-1):
386401
"""
387402
Only aggregates on axis=0
388403
"""
@@ -392,6 +407,8 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
392407
ndarray[{{dest_type2}}, ndim=2] resx
393408
ndarray[int64_t, ndim=2] nobs
394409

410+
assert min_count == -1, "'min_count' only used in add and prod"
411+
395412
if not len(values) == len(labels):
396413
raise AssertionError("len(index) != len(labels)")
397414

@@ -455,7 +472,8 @@ def get_dispatch(dtypes):
455472
def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
456473
ndarray[int64_t] counts,
457474
ndarray[{{dest_type2}}, ndim=2] values,
458-
ndarray[int64_t] labels):
475+
ndarray[int64_t] labels,
476+
Py_ssize_t min_count=-1):
459477
"""
460478
Only aggregates on axis=0
461479
"""
@@ -464,6 +482,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
464482
{{dest_type2}} val, count
465483
ndarray[{{dest_type2}}, ndim=2] maxx, nobs
466484

485+
assert min_count == -1, "'min_count' only used in add and prod"
486+
467487
if not len(values) == len(labels):
468488
raise AssertionError("len(index) != len(labels)")
469489

@@ -526,7 +546,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
526546
def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
527547
ndarray[int64_t] counts,
528548
ndarray[{{dest_type2}}, ndim=2] values,
529-
ndarray[int64_t] labels):
549+
ndarray[int64_t] labels,
550+
Py_ssize_t min_count=-1):
530551
"""
531552
Only aggregates on axis=0
532553
"""
@@ -535,6 +556,8 @@ def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
535556
{{dest_type2}} val, count
536557
ndarray[{{dest_type2}}, ndim=2] minx, nobs
537558

559+
assert min_count == -1, "'min_count' only used in add and prod"
560+
538561
if not len(values) == len(labels):
539562
raise AssertionError("len(index) != len(labels)")
540563

@@ -686,7 +709,8 @@ def group_cummax_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
686709
def group_median_float64(ndarray[float64_t, ndim=2] out,
687710
ndarray[int64_t] counts,
688711
ndarray[float64_t, ndim=2] values,
689-
ndarray[int64_t] labels):
712+
ndarray[int64_t] labels,
713+
Py_ssize_t min_count=-1):
690714
"""
691715
Only aggregates on axis=0
692716
"""
@@ -695,6 +719,9 @@ def group_median_float64(ndarray[float64_t, ndim=2] out,
695719
ndarray[int64_t] _counts
696720
ndarray data
697721
float64_t* ptr
722+
723+
assert min_count == -1, "'min_count' only used in add and prod"
724+
698725
ngroups = len(counts)
699726
N, K = (<object> values).shape
700727

pandas/core/generic.py

+96-8
Original file line numberDiff line numberDiff line change
@@ -7322,7 +7322,8 @@ def _add_numeric_operations(cls):
73227322
@Substitution(outname='mad',
73237323
desc="Return the mean absolute deviation of the values "
73247324
"for the requested axis",
7325-
name1=name, name2=name2, axis_descr=axis_descr)
7325+
name1=name, name2=name2, axis_descr=axis_descr,
7326+
min_count='', examples='')
73267327
@Appender(_num_doc)
73277328
def mad(self, axis=None, skipna=None, level=None):
73287329
if skipna is None:
@@ -7363,7 +7364,8 @@ def mad(self, axis=None, skipna=None, level=None):
73637364
@Substitution(outname='compounded',
73647365
desc="Return the compound percentage of the values for "
73657366
"the requested axis", name1=name, name2=name2,
7366-
axis_descr=axis_descr)
7367+
axis_descr=axis_descr,
7368+
min_count='', examples='')
73677369
@Appender(_num_doc)
73687370
def compound(self, axis=None, skipna=None, level=None):
73697371
if skipna is None:
@@ -7387,10 +7389,10 @@ def compound(self, axis=None, skipna=None, level=None):
73877389
lambda y, axis: np.maximum.accumulate(y, axis), "max",
73887390
-np.inf, np.nan)
73897391

7390-
cls.sum = _make_stat_function(
7392+
cls.sum = _make_min_count_stat_function(
73917393
cls, 'sum', name, name2, axis_descr,
73927394
'Return the sum of the values for the requested axis',
7393-
nanops.nansum)
7395+
nanops.nansum, _sum_examples)
73947396
cls.mean = _make_stat_function(
73957397
cls, 'mean', name, name2, axis_descr,
73967398
'Return the mean of the values for the requested axis',
@@ -7406,10 +7408,10 @@ def compound(self, axis=None, skipna=None, level=None):
74067408
"by N-1\n",
74077409
nanops.nankurt)
74087410
cls.kurtosis = cls.kurt
7409-
cls.prod = _make_stat_function(
7411+
cls.prod = _make_min_count_stat_function(
74107412
cls, 'prod', name, name2, axis_descr,
74117413
'Return the product of the values for the requested axis',
7412-
nanops.nanprod)
7414+
nanops.nanprod, _prod_examples)
74137415
cls.product = cls.prod
74147416
cls.median = _make_stat_function(
74157417
cls, 'median', name, name2, axis_descr,
@@ -7540,10 +7542,13 @@ def _doc_parms(cls):
75407542
numeric_only : boolean, default None
75417543
Include only float, int, boolean columns. If None, will attempt to use
75427544
everything, then use only numeric data. Not implemented for Series.
7545+
%(min_count)s\
75437546
75447547
Returns
75457548
-------
7546-
%(outname)s : %(name1)s or %(name2)s (if level specified)\n"""
7549+
%(outname)s : %(name1)s or %(name2)s (if level specified)
7550+
7551+
%(examples)s"""
75477552

75487553
_num_ddof_doc = """
75497554
@@ -7611,9 +7616,92 @@ def _doc_parms(cls):
76117616
"""
76127617

76137618

7619+
_sum_examples = """\
7620+
Examples
7621+
--------
7622+
By default, the sum of an empty series is ``NaN``.
7623+
7624+
>>> pd.Series([]).sum() # min_count=1 is the default
7625+
nan
7626+
7627+
This can be controlled with the ``min_count`` parameter. For example, if
7628+
you'd like the sum of an empty series to be 0, pass ``min_count=0``.
7629+
7630+
>>> pd.Series([]).sum(min_count=0)
7631+
0.0
7632+
7633+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7634+
empty series identically.
7635+
7636+
>>> pd.Series([np.nan]).sum()
7637+
nan
7638+
7639+
>>> pd.Series([np.nan]).sum(min_count=0)
7640+
0.0
7641+
"""
7642+
7643+
_prod_examples = """\
7644+
Examples
7645+
--------
7646+
By default, the product of an empty series is ``NaN``
7647+
7648+
>>> pd.Series([]).prod()
7649+
nan
7650+
7651+
This can be controlled with the ``min_count`` parameter
7652+
7653+
>>> pd.Series([]).prod(min_count=0)
7654+
1.0
7655+
7656+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7657+
empty series identically.
7658+
7659+
>>> pd.Series([np.nan]).prod()
7660+
nan
7661+
7662+
>>> pd.Series([np.nan]).sum(min_count=0)
7663+
1.0
7664+
"""
7665+
7666+
7667+
_min_count_stub = """\
7668+
min_count : int, default 1
7669+
The required number of valid values to perform the operation. If fewer than
7670+
``min_count`` non-NA values are present the result will be NA.
7671+
7672+
.. versionadded :: 0.21.2
7673+
7674+
Added with the default being 1. This means the sum or product
7675+
of an all-NA or empty series is ``NaN``.
7676+
"""
7677+
7678+
7679+
def _make_min_count_stat_function(cls, name, name1, name2, axis_descr, desc,
7680+
f, examples):
7681+
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7682+
axis_descr=axis_descr, min_count=_min_count_stub,
7683+
examples=examples)
7684+
@Appender(_num_doc)
7685+
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
7686+
min_count=1,
7687+
**kwargs):
7688+
nv.validate_stat_func(tuple(), kwargs, fname=name)
7689+
if skipna is None:
7690+
skipna = True
7691+
if axis is None:
7692+
axis = self._stat_axis_number
7693+
if level is not None:
7694+
return self._agg_by_level(name, axis=axis, level=level,
7695+
skipna=skipna, min_count=min_count)
7696+
return self._reduce(f, name, axis=axis, skipna=skipna,
7697+
numeric_only=numeric_only, min_count=min_count)
7698+
7699+
return set_function_name(stat_func, name, cls)
7700+
7701+
76147702
def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f):
76157703
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7616-
axis_descr=axis_descr)
7704+
axis_descr=axis_descr, min_count='', examples='')
76177705
@Appender(_num_doc)
76187706
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
76197707
**kwargs):

0 commit comments

Comments
 (0)