Commit bce70c8

Backport PR #43467: Make group_mean compatible with NaT (#43630)

Co-authored-by: Alexey Györi <[email protected]>
Parent: 2b41d7b

6 files changed: +130 −10 lines changed

doc/source/whatsnew/v1.3.4.rst (+1 −2)

@@ -27,8 +27,7 @@ Fixed regressions

 Bug fixes
 ~~~~~~~~~
--
--
+- Fixed bug in :meth:`.GroupBy.mean` with datetimelike values including ``NaT`` values returning incorrect results (:issue:`43132`)

 .. ---------------------------------------------------------------------------
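For context, this is the user-visible regression the entry records, reproduced from the tests added in this commit: before the fix, the NaT sentinel (stored as the minimum int64) leaked into the group sum and skewed the result.

import pandas as pd

data = pd.Series(["1 day", "3 days", "NaT"], dtype="timedelta64[ns]")
print(data.groupby([0, 0, 0]).mean())
# with the fix, NaT is skipped:
# 0   2 days
# dtype: timedelta64[ns]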

pandas/_libs/groupby.pyi (+4 −1)

@@ -74,7 +74,10 @@ def group_mean(
     counts: np.ndarray,  # int64_t[::1]
     values: np.ndarray,  # ndarray[floating, ndim=2]
     labels: np.ndarray,  # const intp_t[:]
-    min_count: int = ...,
+    min_count: int = ...,  # Py_ssize_t
+    is_datetimelike: bool = ...,  # bint
+    mask: np.ndarray | None = ...,
+    result_mask: np.ndarray | None = ...,
 ) -> None: ...
 def group_ohlc(
     out: np.ndarray,  # floating[:, ::1]
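The widened stub reflects how the new tests call the Cython routine directly. A minimal direct call, adapted from test_cython_group_mean_datetimelike added below (group_mean is a private pandas API that writes into its out and counts buffers in place):

import numpy as np
from pandas._libs.groupby import group_mean

out = np.zeros((1, 1), dtype="float64")   # result buffer, filled in place
counts = np.array([0], dtype="int64")     # receives the group sizes
values = np.array([[2.0], [4.0]])         # 2-d float view of the raw data
labels = np.zeros(2, dtype=np.intp)       # both rows belong to group 0

group_mean(out, counts, values, labels, is_datetimelike=True)
# out[0, 0] == 3.0, counts[0] == 2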

pandas/_libs/groupby.pyx (+41 −5)

@@ -669,10 +669,45 @@ def group_mean(floating[:, ::1] out,
                int64_t[::1] counts,
                ndarray[floating, ndim=2] values,
                const intp_t[::1] labels,
-               Py_ssize_t min_count=-1) -> None:
+               Py_ssize_t min_count=-1,
+               bint is_datetimelike=False,
+               const uint8_t[:, ::1] mask=None,
+               uint8_t[:, ::1] result_mask=None
+               ) -> None:
+    """
+    Compute the mean per label given a label assignment for each value.
+    NaN values are ignored.
+
+    Parameters
+    ----------
+    out : np.ndarray[floating]
+        Values into which this method will write its results.
+    counts : np.ndarray[int64]
+        A zeroed array of the same shape as labels,
+        populated by group sizes during algorithm.
+    values : np.ndarray[floating]
+        2-d array of the values to find the mean of.
+    labels : np.ndarray[np.intp]
+        Array containing unique label for each group, with its
+        ordering matching up to the corresponding record in `values`.
+    min_count : Py_ssize_t
+        Only used in add and prod. Always -1.
+    is_datetimelike : bool
+        True if `values` contains datetime-like entries.
+    mask : ndarray[bool, ndim=2], optional
+        Not used.
+    result_mask : ndarray[bool, ndim=2], optional
+        Not used.
+
+    Notes
+    -----
+    This method modifies the `out` parameter rather than returning an object.
+    `counts` is modified to hold group sizes
+    """
+
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
-        floating val, count, y, t
+        floating val, count, y, t, nan_val
         floating[:, ::1] sumx, compensation
         int64_t[:, ::1] nobs
         Py_ssize_t len_values = len(values), len_labels = len(labels)

@@ -682,12 +717,13 @@ def group_mean(floating[:, ::1] out,
     if len_values != len_labels:
         raise ValueError("len(index) != len(labels)")

-    nobs = np.zeros((<object>out).shape, dtype=np.int64)
     # the below is equivalent to `np.zeros_like(out)` but faster
+    nobs = np.zeros((<object>out).shape, dtype=np.int64)
     sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
     compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)

     N, K = (<object>values).shape
+    nan_val = NPY_NAT if is_datetimelike else NAN

     with nogil:
         for i in range(N):

@@ -699,7 +735,7 @@ def group_mean(floating[:, ::1] out,
             for j in range(K):
                 val = values[i, j]
                 # not nan
-                if val == val:
+                if val == val and not (is_datetimelike and val == NPY_NAT):
                     nobs[lab, j] += 1
                     y = val - compensation[lab, j]
                     t = sumx[lab, j] + y

@@ -710,7 +746,7 @@ def group_mean(floating[:, ::1] out,
             for j in range(K):
                 count = nobs[i, j]
                 if nobs[i, j] == 0:
-                    out[i, j] = NAN
+                    out[i, j] = nan_val
                 else:
                     out[i, j] = sumx[i, j] / count
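The substance of the change is two lines in the hot loop: datetimelike columns reach group_mean as a float view of int64 data, so NaT passes the val == val NaN test and must be excluded explicitly, and empty groups must be filled with the NaT sentinel instead of NaN. A pure-Python sketch of the same algorithm, for illustration only (group_mean_sketch and its flat 1-d signature are hypothetical; the real routine above is Cython over 2-d buffers):

import numpy as np

NPY_NAT = np.iinfo(np.int64).min  # numpy's NaT sentinel

def group_mean_sketch(values, labels, ngroups, is_datetimelike=False):
    # Per-group mean with Kahan-compensated summation, skipping NaN/NaT.
    nan_val = NPY_NAT if is_datetimelike else np.nan
    sumx = np.zeros(ngroups)
    compensation = np.zeros(ngroups)
    nobs = np.zeros(ngroups, dtype=np.int64)
    for val, lab in zip(values, labels):
        if lab < 0:
            continue
        # skip NaN; for datetimelike data also skip the NaT sentinel,
        # which is an ordinary (very negative) float after the int64 view
        if val == val and not (is_datetimelike and val == NPY_NAT):
            nobs[lab] += 1
            y = val - compensation[lab]
            t = sumx[lab] + y
            compensation[lab] = t - sumx[lab] - y
            sumx[lab] = t
    out = np.empty(ngroups)
    for i in range(ngroups):
        out[i] = nan_val if nobs[i] == 0 else sumx[i] / nobs[i]
    return out

# e.g. group_mean_sketch([2.0, 4.0, float(NPY_NAT)], [0, 0, 0], 1,
#                        is_datetimelike=True) -> array([3.])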

pandas/core/groupby/ops.py (+1 −1)

@@ -544,7 +544,7 @@ def _call_cython_op(
         result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
         if self.kind == "aggregate":
             counts = np.zeros(ngroups, dtype=np.int64)
-            if self.how in ["min", "max"]:
+            if self.how in ["min", "max", "mean"]:
                 func(
                     result,
                     counts,
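Adding "mean" here moves group_mean into the branch that forwards the extra keyword; roughly (a simplified sketch, variable names as in the surrounding _call_cython_op, exact argument lists elided):

if self.how in ["min", "max", "mean"]:
    # these aggregations need to know about datetime-like input
    func(result, counts, values, comp_ids, min_count,
         is_datetimelike=is_datetimelike)
else:
    func(result, counts, values, comp_ids, min_count)

Without this routing change, the new is_datetimelike parameter would keep its False default and NaT would still be averaged as an ordinary float.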

pandas/tests/groupby/aggregate/test_aggregate.py (+33 −1)

@@ -20,6 +20,7 @@
     MultiIndex,
     Series,
     concat,
+    to_datetime,
 )
 import pandas._testing as tm
 from pandas.core.base import SpecificationError

@@ -66,7 +67,6 @@ def test_agg_ser_multi_key(df):



 def test_groupby_aggregation_mixed_dtype():
-
     # GH 6212
     expected = DataFrame(
         {

@@ -1274,3 +1274,35 @@ def func(ser):

     expected = DataFrame([[1.0]], index=[1])
     tm.assert_frame_equal(res, expected)
+
+
+def test_group_mean_timedelta_nat():
+    # GH43132
+    data = Series(["1 day", "3 days", "NaT"], dtype="timedelta64[ns]")
+    expected = Series(["2 days"], dtype="timedelta64[ns]")
+
+    result = data.groupby([0, 0, 0]).mean()
+
+    tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "input_data, expected_output",
+    [
+        (  # no timezone
+            ["2021-01-01T00:00", "NaT", "2021-01-01T02:00"],
+            ["2021-01-01T01:00"],
+        ),
+        (  # timezone
+            ["2021-01-01T00:00-0100", "NaT", "2021-01-01T02:00-0100"],
+            ["2021-01-01T01:00-0100"],
+        ),
+    ],
+)
+def test_group_mean_datetime64_nat(input_data, expected_output):
+    # GH43132
+    data = to_datetime(Series(input_data))
+    expected = to_datetime(Series(expected_output))
+
+    result = data.groupby([0, 0, 0]).mean()
+    tm.assert_series_equal(result, expected)
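Both parametrized cases exercise the same code path; the tz-aware one can also be checked interactively, mirroring the test above:

import pandas as pd

data = pd.to_datetime(
    pd.Series(["2021-01-01T00:00-0100", "NaT", "2021-01-01T02:00-0100"])
)
print(data.groupby([0, 0, 0]).mean())
# 0   2021-01-01 01:00:00-01:00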

pandas/tests/groupby/test_libgroupby.py (+50)

@@ -1,9 +1,11 @@
 import numpy as np
+import pytest

 from pandas._libs import groupby as libgroupby
 from pandas._libs.groupby import (
     group_cumprod_float64,
     group_cumsum,
+    group_mean,
     group_var,
 )

@@ -234,3 +236,51 @@ def test_cython_group_transform_algos():
         ]
     )
     tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected)
+
+
+def test_cython_group_mean_datetimelike():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.array([0], dtype="int64")
+    data = (
+        np.array(
+            [np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")],
+            dtype="m8[ns]",
+        )[:, None]
+        .view("int64")
+        .astype("float64")
+    )
+    labels = np.zeros(len(data), dtype=np.intp)
+
+    group_mean(actual, counts, data, labels, is_datetimelike=True)
+
+    tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64"))
+
+
+def test_cython_group_mean_wrong_min_count():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.zeros(1, dtype="int64")
+    data = np.zeros(1, dtype="float64")[:, None]
+    labels = np.zeros(1, dtype=np.intp)
+
+    with pytest.raises(AssertionError, match="min_count"):
+        group_mean(actual, counts, data, labels, is_datetimelike=True, min_count=0)
+
+
+def test_cython_group_mean_not_datetimelike_but_has_NaT_values():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.array([0], dtype="int64")
+    data = (
+        np.array(
+            [np.timedelta64("NaT"), np.timedelta64("NaT")],
+            dtype="m8[ns]",
+        )[:, None]
+        .view("int64")
+        .astype("float64")
+    )
+    labels = np.zeros(len(data), dtype=np.intp)
+
+    group_mean(actual, counts, data, labels, is_datetimelike=False)
+
+    tm.assert_numpy_array_equal(
+        actual[:, 0], np.array(np.divide(np.add(data[0], data[1]), 2), dtype="float64")
+    )
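The last test pins down intentional behavior rather than an oversight: with is_datetimelike=False the sentinel is not special, it is just the minimum int64 viewed as a float, so it participates in the mean. A quick check of the arithmetic the final assertion relies on:

import numpy as np

nat_as_float = float(np.iinfo(np.int64).min)  # about -9.2233720368547758e+18
assert (nat_as_float + nat_as_float) / 2 == nat_as_float  # two "NaT" floats average to themselves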
