Skip to content

Commit b9decb6

Browse files
committed
ENH: Added a min_count keyword to stat funcs
The current default is 1, reproducing the behavior of pandas 0.21. The current test suite should pass. Currently, only nansum and nanprod actually do anything with `min_count`. It will not be hard to adjust other nan* methods use it if we want. This was just simplest for now. Additional tests for the new behavior have been added.
1 parent 316acbf commit b9decb6

10 files changed

+445
-83
lines changed

pandas/_libs/groupby_helper.pxi.in

+39-12
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def get_dispatch(dtypes):
3636
def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
3737
ndarray[int64_t] counts,
3838
ndarray[{{c_type}}, ndim=2] values,
39-
ndarray[int64_t] labels):
39+
ndarray[int64_t] labels,
40+
Py_ssize_t min_count=1):
4041
"""
4142
Only aggregates on axis=0
4243
"""
@@ -88,7 +89,7 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
8889

8990
for i in range(ncounts):
9091
for j in range(K):
91-
if nobs[i, j] == 0:
92+
if nobs[i, j] < min_count:
9293
out[i, j] = NAN
9394
else:
9495
out[i, j] = sumx[i, j]
@@ -99,7 +100,8 @@ def group_add_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
99100
def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
100101
ndarray[int64_t] counts,
101102
ndarray[{{c_type}}, ndim=2] values,
102-
ndarray[int64_t] labels):
103+
ndarray[int64_t] labels,
104+
Py_ssize_t min_count=1):
103105
"""
104106
Only aggregates on axis=0
105107
"""
@@ -147,7 +149,7 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
147149

148150
for i in range(ncounts):
149151
for j in range(K):
150-
if nobs[i, j] == 0:
152+
if nobs[i, j] < min_count:
151153
out[i, j] = NAN
152154
else:
153155
out[i, j] = prodx[i, j]
@@ -159,12 +161,15 @@ def group_prod_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
159161
def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
160162
ndarray[int64_t] counts,
161163
ndarray[{{dest_type2}}, ndim=2] values,
162-
ndarray[int64_t] labels):
164+
ndarray[int64_t] labels,
165+
Py_ssize_t min_count=-1):
163166
cdef:
164167
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
165168
{{dest_type2}} val, ct, oldmean
166169
ndarray[{{dest_type2}}, ndim=2] nobs, mean
167170

171+
assert min_count == -1, "'min_count' only used in add and prod"
172+
168173
if not len(values) == len(labels):
169174
raise AssertionError("len(index) != len(labels)")
170175

@@ -208,12 +213,15 @@ def group_var_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
208213
def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
209214
ndarray[int64_t] counts,
210215
ndarray[{{dest_type2}}, ndim=2] values,
211-
ndarray[int64_t] labels):
216+
ndarray[int64_t] labels,
217+
Py_ssize_t min_count=-1):
212218
cdef:
213219
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
214220
{{dest_type2}} val, count
215221
ndarray[{{dest_type2}}, ndim=2] sumx, nobs
216222

223+
assert min_count == -1, "'min_count' only used in add and prod"
224+
217225
if not len(values) == len(labels):
218226
raise AssertionError("len(index) != len(labels)")
219227

@@ -263,7 +271,8 @@ def group_mean_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
263271
def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
264272
ndarray[int64_t] counts,
265273
ndarray[{{dest_type2}}, ndim=2] values,
266-
ndarray[int64_t] labels):
274+
ndarray[int64_t] labels,
275+
Py_ssize_t min_count=-1):
267276
"""
268277
Only aggregates on axis=0
269278
"""
@@ -272,6 +281,8 @@ def group_ohlc_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
272281
{{dest_type2}} val, count
273282
Py_ssize_t ngroups = len(counts)
274283

284+
assert min_count == -1, "'min_count' only used in add and prod"
285+
275286
if len(labels) == 0:
276287
return
277288

@@ -332,7 +343,8 @@ def get_dispatch(dtypes):
332343
def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
333344
ndarray[int64_t] counts,
334345
ndarray[{{c_type}}, ndim=2] values,
335-
ndarray[int64_t] labels):
346+
ndarray[int64_t] labels,
347+
Py_ssize_t min_count=-1):
336348
"""
337349
Only aggregates on axis=0
338350
"""
@@ -342,6 +354,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
342354
ndarray[{{dest_type2}}, ndim=2] resx
343355
ndarray[int64_t, ndim=2] nobs
344356

357+
assert min_count == -1, "'min_count' only used in add and prod"
358+
345359
if not len(values) == len(labels):
346360
raise AssertionError("len(index) != len(labels)")
347361

@@ -382,7 +396,8 @@ def group_last_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
382396
def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
383397
ndarray[int64_t] counts,
384398
ndarray[{{c_type}}, ndim=2] values,
385-
ndarray[int64_t] labels, int64_t rank):
399+
ndarray[int64_t] labels, int64_t rank,
400+
Py_ssize_t min_count=-1):
386401
"""
387402
Only aggregates on axis=0
388403
"""
@@ -392,6 +407,8 @@ def group_nth_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
392407
ndarray[{{dest_type2}}, ndim=2] resx
393408
ndarray[int64_t, ndim=2] nobs
394409

410+
assert min_count == -1, "'min_count' only used in add and prod"
411+
395412
if not len(values) == len(labels):
396413
raise AssertionError("len(index) != len(labels)")
397414

@@ -455,7 +472,8 @@ def get_dispatch(dtypes):
455472
def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
456473
ndarray[int64_t] counts,
457474
ndarray[{{dest_type2}}, ndim=2] values,
458-
ndarray[int64_t] labels):
475+
ndarray[int64_t] labels,
476+
Py_ssize_t min_count=-1):
459477
"""
460478
Only aggregates on axis=0
461479
"""
@@ -464,6 +482,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
464482
{{dest_type2}} val, count
465483
ndarray[{{dest_type2}}, ndim=2] maxx, nobs
466484

485+
assert min_count == -1, "'min_count' only used in add and prod"
486+
467487
if not len(values) == len(labels):
468488
raise AssertionError("len(index) != len(labels)")
469489

@@ -526,7 +546,8 @@ def group_max_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
526546
def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
527547
ndarray[int64_t] counts,
528548
ndarray[{{dest_type2}}, ndim=2] values,
529-
ndarray[int64_t] labels):
549+
ndarray[int64_t] labels,
550+
Py_ssize_t min_count=-1):
530551
"""
531552
Only aggregates on axis=0
532553
"""
@@ -535,6 +556,8 @@ def group_min_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
535556
{{dest_type2}} val, count
536557
ndarray[{{dest_type2}}, ndim=2] minx, nobs
537558

559+
assert min_count == -1, "'min_count' only used in add and prod"
560+
538561
if not len(values) == len(labels):
539562
raise AssertionError("len(index) != len(labels)")
540563

@@ -686,7 +709,8 @@ def group_cummax_{{name}}(ndarray[{{dest_type2}}, ndim=2] out,
686709
def group_median_float64(ndarray[float64_t, ndim=2] out,
687710
ndarray[int64_t] counts,
688711
ndarray[float64_t, ndim=2] values,
689-
ndarray[int64_t] labels):
712+
ndarray[int64_t] labels,
713+
Py_ssize_t min_count=-1):
690714
"""
691715
Only aggregates on axis=0
692716
"""
@@ -695,6 +719,9 @@ def group_median_float64(ndarray[float64_t, ndim=2] out,
695719
ndarray[int64_t] _counts
696720
ndarray data
697721
float64_t* ptr
722+
723+
assert min_count == -1, "'min_count' only used in add and prod"
724+
698725
ngroups = len(counts)
699726
N, K = (<object> values).shape
700727

pandas/core/generic.py

+96-8
Original file line numberDiff line numberDiff line change
@@ -7322,7 +7322,8 @@ def _add_numeric_operations(cls):
73227322
@Substitution(outname='mad',
73237323
desc="Return the mean absolute deviation of the values "
73247324
"for the requested axis",
7325-
name1=name, name2=name2, axis_descr=axis_descr)
7325+
name1=name, name2=name2, axis_descr=axis_descr,
7326+
min_count='', examples='')
73267327
@Appender(_num_doc)
73277328
def mad(self, axis=None, skipna=None, level=None):
73287329
if skipna is None:
@@ -7363,7 +7364,8 @@ def mad(self, axis=None, skipna=None, level=None):
73637364
@Substitution(outname='compounded',
73647365
desc="Return the compound percentage of the values for "
73657366
"the requested axis", name1=name, name2=name2,
7366-
axis_descr=axis_descr)
7367+
axis_descr=axis_descr,
7368+
min_count='', examples='')
73677369
@Appender(_num_doc)
73687370
def compound(self, axis=None, skipna=None, level=None):
73697371
if skipna is None:
@@ -7387,10 +7389,10 @@ def compound(self, axis=None, skipna=None, level=None):
73877389
lambda y, axis: np.maximum.accumulate(y, axis), "max",
73887390
-np.inf, np.nan)
73897391

7390-
cls.sum = _make_stat_function(
7392+
cls.sum = _make_min_count_stat_function(
73917393
cls, 'sum', name, name2, axis_descr,
73927394
'Return the sum of the values for the requested axis',
7393-
nanops.nansum)
7395+
nanops.nansum, _sum_examples)
73947396
cls.mean = _make_stat_function(
73957397
cls, 'mean', name, name2, axis_descr,
73967398
'Return the mean of the values for the requested axis',
@@ -7406,10 +7408,10 @@ def compound(self, axis=None, skipna=None, level=None):
74067408
"by N-1\n",
74077409
nanops.nankurt)
74087410
cls.kurtosis = cls.kurt
7409-
cls.prod = _make_stat_function(
7411+
cls.prod = _make_min_count_stat_function(
74107412
cls, 'prod', name, name2, axis_descr,
74117413
'Return the product of the values for the requested axis',
7412-
nanops.nanprod)
7414+
nanops.nanprod, _prod_examples)
74137415
cls.product = cls.prod
74147416
cls.median = _make_stat_function(
74157417
cls, 'median', name, name2, axis_descr,
@@ -7540,10 +7542,13 @@ def _doc_parms(cls):
75407542
numeric_only : boolean, default None
75417543
Include only float, int, boolean columns. If None, will attempt to use
75427544
everything, then use only numeric data. Not implemented for Series.
7545+
%(min_count)s\
75437546
75447547
Returns
75457548
-------
7546-
%(outname)s : %(name1)s or %(name2)s (if level specified)\n"""
7549+
%(outname)s : %(name1)s or %(name2)s (if level specified)
7550+
7551+
%(examples)s"""
75477552

75487553
_num_ddof_doc = """
75497554
@@ -7611,9 +7616,92 @@ def _doc_parms(cls):
76117616
"""
76127617

76137618

7619+
_sum_examples = """\
7620+
Examples
7621+
--------
7622+
By default, the sum of an empty series is ``NaN``.
7623+
7624+
>>> pd.Series([]).sum() # min_count=1 is the default
7625+
nan
7626+
7627+
This can be controlled with the ``min_count`` parameter. For example, if
7628+
you'd like the sum of an empty series to be 0, pass ``min_count=0``.
7629+
7630+
>>> pd.Series([]).sum(min_count=0)
7631+
0.0
7632+
7633+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7634+
empty series identically.
7635+
7636+
>>> pd.Series([np.nan]).sum()
7637+
nan
7638+
7639+
>>> pd.Series([np.nan]).sum(min_count=0)
7640+
0.0
7641+
"""
7642+
7643+
_prod_examples = """\
7644+
Examples
7645+
--------
7646+
By default, the product of an empty series is ``NaN``
7647+
7648+
>>> pd.Series([]).prod()
7649+
nan
7650+
7651+
This can be controlled with the ``min_count`` parameter
7652+
7653+
>>> pd.Series([]).prod(min_count=0)
7654+
1.0
7655+
7656+
Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and
7657+
empty series identically.
7658+
7659+
>>> pd.Series([np.nan]).prod()
7660+
nan
7661+
7662+
>>> pd.Series([np.nan]).sum(min_count=0)
7663+
1.0
7664+
"""
7665+
7666+
7667+
_min_count_stub = """\
7668+
min_count : int, default 1
7669+
The required number of valid values to perform the operation. If fewer than
7670+
``min_count`` non-NA values are present the result will be NA.
7671+
7672+
.. versionadded :: 0.21.2
7673+
7674+
Added with the default being 1. This means the sum or product
7675+
of an all-NA or empty series is ``NaN``.
7676+
"""
7677+
7678+
7679+
def _make_min_count_stat_function(cls, name, name1, name2, axis_descr, desc,
7680+
f, examples):
7681+
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7682+
axis_descr=axis_descr, min_count=_min_count_stub,
7683+
examples=examples)
7684+
@Appender(_num_doc)
7685+
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
7686+
min_count=1,
7687+
**kwargs):
7688+
nv.validate_stat_func(tuple(), kwargs, fname=name)
7689+
if skipna is None:
7690+
skipna = True
7691+
if axis is None:
7692+
axis = self._stat_axis_number
7693+
if level is not None:
7694+
return self._agg_by_level(name, axis=axis, level=level,
7695+
skipna=skipna, min_count=min_count)
7696+
return self._reduce(f, name, axis=axis, skipna=skipna,
7697+
numeric_only=numeric_only, min_count=min_count)
7698+
7699+
return set_function_name(stat_func, name, cls)
7700+
7701+
76147702
def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f):
76157703
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
7616-
axis_descr=axis_descr)
7704+
axis_descr=axis_descr, min_count='', examples='')
76177705
@Appender(_num_doc)
76187706
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
76197707
**kwargs):

0 commit comments

Comments
 (0)