GH#43701: support complex dtypes in the groupby mean kernel.

This free-text preamble precedes the first "diff --git" header and is skipped
by patch tools.

* doc/source/whatsnew/v1.4.0.rst: add the release-note entry.
* pandas/_libs/groupby.pyx: introduce the fused type mean_t (float64, float32,
  complex64, complex128) and use it in group_mean in place of floating.
* pandas/core/groupby/ops.py: _get_result_dtype returns a complex input dtype
  unchanged for "mean", "median" and "var", as it already does for floats.
* pandas/tests/groupby/aggregate/test_aggregate.py: check "mean" and "sum" on
  a complex Series, and that "min", "max" and "var" raise TypeError.

diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst
index c8fa6d46c3b7e..ed38c09129679 100644
--- a/doc/source/whatsnew/v1.4.0.rst
+++ b/doc/source/whatsnew/v1.4.0.rst
@@ -488,6 +488,7 @@ Groupby/resample/rolling
 - Bug in :meth:`DataFrame.rolling.corr` when the :class:`DataFrame` columns was a :class:`MultiIndex` (:issue:`21157`)
 - Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`)
 - Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`)
+- Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`)
 
 Reshaping
 ^^^^^^^^^
diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx
index bbdc5a8287502..6dfed95e7afb6 100644
--- a/pandas/_libs/groupby.pyx
+++ b/pandas/_libs/groupby.pyx
@@ -481,6 +481,12 @@ ctypedef fused add_t:
     complex128_t
     object
 
+ctypedef fused mean_t:
+    float64_t
+    float32_t
+    complex64_t
+    complex128_t
+
 
 @cython.wraparound(False)
 @cython.boundscheck(False)
@@ -670,9 +676,9 @@ def group_var(floating[:, ::1] out,
 
 @cython.wraparound(False)
 @cython.boundscheck(False)
-def group_mean(floating[:, ::1] out,
+def group_mean(mean_t[:, ::1] out,
                int64_t[::1] counts,
-               ndarray[floating, ndim=2] values,
+               ndarray[mean_t, ndim=2] values,
                const intp_t[::1] labels,
                Py_ssize_t min_count=-1,
                bint is_datetimelike=False,
@@ -712,8 +718,8 @@ def group_mean(floating[:, ::1] out,
 
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
-        floating val, count, y, t, nan_val
-        floating[:, ::1] sumx, compensation
+        mean_t val, count, y, t, nan_val
+        mean_t[:, ::1] sumx, compensation
         int64_t[:, ::1] nobs
         Py_ssize_t len_values = len(values), len_labels = len(labels)
 
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 07353863cdc0b..89bf14d649a99 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -307,7 +307,7 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
         elif how in ["mean", "median", "var"]:
             if isinstance(dtype, (BooleanDtype, _IntegerDtype)):
                 return Float64Dtype()
-            elif is_float_dtype(dtype):
+            elif is_float_dtype(dtype) or is_complex_dtype(dtype):
                 return dtype
             elif is_numeric_dtype(dtype):
                 return np.dtype(np.float64)
diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py
index 8d1962f7cab3b..2c798e543bf6b 100644
--- a/pandas/tests/groupby/aggregate/test_aggregate.py
+++ b/pandas/tests/groupby/aggregate/test_aggregate.py
@@ -1357,3 +1357,23 @@ def test_group_mean_datetime64_nat(input_data, expected_output):
 
     result = data.groupby([0, 0, 0]).mean()
     tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "func, output", [("mean", [8 + 18j, 10 + 22j]), ("sum", [40 + 90j, 50 + 110j])]
+)
+def test_groupby_complex(func, output):
+    # GH#43701
+    data = Series(np.arange(20).reshape(10, 2).dot([1, 2j]))
+    result = data.groupby(data.index % 2).agg(func)
+    expected = Series(output)
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize("func", ["min", "max", "var"])
+def test_groupby_complex_raises(func):
+    # GH#43701
+    data = Series(np.arange(20).reshape(10, 2).dot([1, 2j]))
+    msg = "No matching signature found"
+    with pytest.raises(TypeError, match=msg):
+        data.groupby(data.index % 2).agg(func)