Skip to content

Commit 97d96c4

Browse files
BUG: groupby mean fails for complex (#43756)
1 parent 69059e8 commit 97d96c4

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -489,6 +489,7 @@ Groupby/resample/rolling
489489
- Bug in :meth:`DataFrame.rolling.corr` when the :class:`DataFrame` columns were a :class:`MultiIndex` (:issue:`21157`)
490490
- Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`)
491491
- Bug in :meth:`GroupBy.apply` with time-based :class:`Grouper` objects incorrectly raising ``ValueError`` in corner cases where the grouping vector contains a ``NaT`` (:issue:`43500`, :issue:`43515`)
492+
- Bug in :meth:`GroupBy.mean` failing with ``complex`` dtype (:issue:`43701`)
492493

493494
Reshaping
494495
^^^^^^^^^

pandas/_libs/groupby.pyx

+10 -4
Original file line number | Diff line number | Diff line change
@@ -481,6 +481,12 @@ ctypedef fused add_t:
481481
complex128_t
482482
object
483483

484+
ctypedef fused mean_t:
485+
float64_t
486+
float32_t
487+
complex64_t
488+
complex128_t
489+
484490

485491
@cython.wraparound(False)
486492
@cython.boundscheck(False)
@@ -670,9 +676,9 @@ def group_var(floating[:, ::1] out,
670676

671677
@cython.wraparound(False)
672678
@cython.boundscheck(False)
673-
def group_mean(floating[:, ::1] out,
679+
def group_mean(mean_t[:, ::1] out,
674680
int64_t[::1] counts,
675-
ndarray[floating, ndim=2] values,
681+
ndarray[mean_t, ndim=2] values,
676682
const intp_t[::1] labels,
677683
Py_ssize_t min_count=-1,
678684
bint is_datetimelike=False,
@@ -712,8 +718,8 @@ def group_mean(floating[:, ::1] out,
712718

713719
cdef:
714720
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
715-
floating val, count, y, t, nan_val
716-
floating[:, ::1] sumx, compensation
721+
mean_t val, count, y, t, nan_val
722+
mean_t[:, ::1] sumx, compensation
717723
int64_t[:, ::1] nobs
718724
Py_ssize_t len_values = len(values), len_labels = len(labels)
719725

pandas/core/groupby/ops.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -307,7 +307,7 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
307307
elif how in ["mean", "median", "var"]:
308308
if isinstance(dtype, (BooleanDtype, _IntegerDtype)):
309309
return Float64Dtype()
310-
elif is_float_dtype(dtype):
310+
elif is_float_dtype(dtype) or is_complex_dtype(dtype):
311311
return dtype
312312
elif is_numeric_dtype(dtype):
313313
return np.dtype(np.float64)

pandas/tests/groupby/aggregate/test_aggregate.py

+20
Original file line number | Diff line number | Diff line change
@@ -1357,3 +1357,23 @@ def test_group_mean_datetime64_nat(input_data, expected_output):
13571357

13581358
result = data.groupby([0, 0, 0]).mean()
13591359
tm.assert_series_equal(result, expected)
1360+
1361+
1362+
@pytest.mark.parametrize(
1363+
"func, output", [("mean", [8 + 18j, 10 + 22j]), ("sum", [40 + 90j, 50 + 110j])]
1364+
)
1365+
def test_groupby_complex(func, output):
1366+
# GH#43701
1367+
data = Series(np.arange(20).reshape(10, 2).dot([1, 2j]))
1368+
result = data.groupby(data.index % 2).agg(func)
1369+
expected = Series(output)
1370+
tm.assert_series_equal(result, expected)
1371+
1372+
1373+
@pytest.mark.parametrize("func", ["min", "max", "var"])
1374+
def test_groupby_complex_raises(func):
1375+
# GH#43701
1376+
data = Series(np.arange(20).reshape(10, 2).dot([1, 2j]))
1377+
msg = "No matching signature found"
1378+
with pytest.raises(TypeError, match=msg):
1379+
data.groupby(data.index % 2).agg(func)

0 commit comments

Comments
 (0)