Skip to content

Commit 2a60c56

Browse files
authored
ENH: add masked algorithm for mean() function (#34814)
1 parent a380726 commit 2a60c56

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ Deprecations
167167
Performance improvements
168168
~~~~~~~~~~~~~~~~~~~~~~~~
169169
- Performance improvement in :meth:`IntervalIndex.isin` (:issue:`38353`)
170-
-
170+
- Performance improvement in :meth:`Series.mean` for nullable data types (:issue:`34814`)
171171
-
172172

173173
.. ---------------------------------------------------------------------------

pandas/core/array_algos/masked_reductions.py

+9
Original file line number | Diff line number | Diff line change
@@ -107,3 +107,12 @@ def min(values: np.ndarray, mask: np.ndarray, *, skipna: bool = True):
107107

108108
def max(values: np.ndarray, mask: np.ndarray, *, skipna: bool = True):
    """
    Masked maximum reduction.

    Thin wrapper that hands ``np.max`` to ``_minmax`` together with the
    data, the mask and the ``skipna`` flag, and returns whatever
    ``_minmax`` returns.

    Parameters
    ----------
    values : np.ndarray
        Data array, passed through unchanged.
    mask : np.ndarray
        Mask array, passed through unchanged.
    skipna : bool, default True
        Keyword-only flag, passed through unchanged.
    """
    return _minmax(np.max, skipna=skipna, mask=mask, values=values)
110+
111+
112+
def mean(values: np.ndarray, mask: np.ndarray, skipna: bool = True):
    """
    Mean of the entries of ``values`` selected by ``mask``.

    Parameters
    ----------
    values : np.ndarray
        Data array.
    mask : np.ndarray
        Boolean array; entries where it is False are the ones counted in
        the divisor.
    skipna : bool, default True
        Forwarded unchanged to ``_sumprod`` when computing the sum.

    Returns
    -------
    ``libmissing.NA`` when ``values`` has no elements or every entry of
    ``mask`` is True; otherwise the result of ``_sumprod`` with ``np.sum``
    divided by the number of False entries in ``mask``.
    """
    # Nothing to average: empty input or every entry masked.
    if values.size == 0 or mask.all():
        return libmissing.NA

    total = _sumprod(np.sum, values=values, mask=mask, skipna=skipna)
    # Divisor is the number of entries where the mask is False.
    n_valid = np.count_nonzero(~mask)
    return total / n_valid

pandas/core/arrays/masked.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
394394
data = self._data
395395
mask = self._mask
396396

397-
if name in {"sum", "prod", "min", "max"}:
397+
if name in {"sum", "prod", "min", "max", "mean"}:
398398
op = getattr(masked_reductions, name)
399399
return op(data, mask, skipna=skipna, **kwargs)
400400

pandas/tests/reductions/test_reductions.py

+17
Original file line number | Diff line number | Diff line change
@@ -688,6 +688,23 @@ def test_empty_multi(self, method, unit):
688688
expected = Series([1, np.nan], index=["a", "b"])
689689
tm.assert_series_equal(result, expected)
690690

691+
@pytest.mark.parametrize("method", ["mean"])
@pytest.mark.parametrize("dtype", ["Float64", "Int64", "boolean"])
def test_ops_consistency_on_empty_nullable(self, method, dtype):
    # GH#34814
    # For each nullable dtype, the reduction named by ``method`` must give
    # the pd.NA singleton both for a Series with no elements and for a
    # Series built from a single np.nan.

    # Series with no elements
    empty = Series([], dtype=dtype)
    assert getattr(empty, method)() is pd.NA

    # Series built from a single np.nan
    all_na = Series([np.nan], dtype=dtype)
    assert getattr(all_na, method)() is pd.NA

707+
691708
@pytest.mark.parametrize("method", ["mean", "median", "std", "var"])
692709
def test_ops_consistency_on_empty(self, method):
693710

0 commit comments

Comments
 (0)