def mean(values: np.ndarray, mask: np.ndarray, skipna: bool = True):
    """
    Mean of the non-masked entries of a masked 1D array.

    Parameters
    ----------
    values : np.ndarray
        Numpy array holding the data.
    mask : np.ndarray
        Boolean numpy array; True marks a missing entry.
    skipna : bool, default True
        Forwarded to ``_sumprod`` when computing the total.
        NOTE(review): with ``skipna=False`` and at least one masked entry
        the helper is expected to return NA, which then propagates through
        the division - confirm against ``_sumprod``.

    Returns
    -------
    scalar or libmissing.NA
        ``libmissing.NA`` when ``values`` is empty or every entry is masked,
        otherwise the total divided by the number of non-masked entries.
    """
    # Nothing to average: empty input or every entry is missing.
    if not values.size or mask.all():
        return libmissing.NA

    if values.dtype.kind in "iub":
        # Integer and boolean totals are accumulated in float64, mirroring
        # what np.mean does for integer input. A plain np.sum would keep the
        # integer dtype and could silently wrap around for large values
        # before the division takes place.
        def _accumulate(arr, **kwargs):
            # Forward whatever keyword arguments the helper supplies.
            return np.sum(arr, dtype=np.float64, **kwargs)

    else:
        _accumulate = np.sum

    total = _sumprod(_accumulate, values=values, mask=mask, skipna=skipna)
    # Number of entries that are not masked out.
    count = np.count_nonzero(~mask)
    return total / count
@pytest.mark.parametrize("method", ["mean"])
@pytest.mark.parametrize("dtype", ["Float64", "Int64", "boolean"])
def test_ops_consistency_on_empty_nullable(self, method, dtype):
    """Reductions on empty or all-NA nullable Series must return pd.NA."""
    # GH#34814
    # consistency for nullable dtypes on empty or ALL-NA mean
    # First the empty series, then the series holding a single missing value.
    for data in ([], [np.nan]):
        ser = Series(data, dtype=dtype)
        reduced = getattr(ser, method)()
        assert reduced is pd.NA