Skip to content

Commit 011c016

Browse files
authored
Make group_mean compatible with NaT (pandas-dev#43467)
1 parent 0de6f8b commit 011c016

File tree

6 files changed

+130
-10
lines changed

6 files changed

+130
-10
lines changed

doc/source/whatsnew/v1.3.4.rst

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ Fixed regressions
2727

2828
Bug fixes
2929
~~~~~~~~~
30-
-
31-
-
30+
- Fixed bug in :meth:`.GroupBy.mean` with datetimelike values including ``NaT`` values returning incorrect results (:issue`:43132`)
3231

3332
.. ---------------------------------------------------------------------------
3433

pandas/_libs/groupby.pyi

+4-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def group_mean(
7777
counts: np.ndarray, # int64_t[::1]
7878
values: np.ndarray, # ndarray[floating, ndim=2]
7979
labels: np.ndarray, # const intp_t[:]
80-
min_count: int = ...,
80+
min_count: int = ..., # Py_ssize_t
81+
is_datetimelike: bool = ..., # bint
82+
mask: np.ndarray | None = ...,
83+
result_mask: np.ndarray | None = ...,
8184
) -> None: ...
8285
def group_ohlc(
8386
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+41-5
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,45 @@ def group_mean(floating[:, ::1] out,
673673
int64_t[::1] counts,
674674
ndarray[floating, ndim=2] values,
675675
const intp_t[::1] labels,
676-
Py_ssize_t min_count=-1) -> None:
676+
Py_ssize_t min_count=-1,
677+
bint is_datetimelike=False,
678+
const uint8_t[:, ::1] mask=None,
679+
uint8_t[:, ::1] result_mask=None
680+
) -> None:
681+
"""
682+
Compute the mean per label given a label assignment for each value.
683+
NaN values are ignored.
684+
685+
Parameters
686+
----------
687+
out : np.ndarray[floating]
688+
Values into which this method will write its results.
689+
counts : np.ndarray[int64]
690+
A zeroed array of the same shape as labels,
691+
populated by group sizes during algorithm.
692+
values : np.ndarray[floating]
693+
2-d array of the values to find the mean of.
694+
labels : np.ndarray[np.intp]
695+
Array containing unique label for each group, with its
696+
ordering matching up to the corresponding record in `values`.
697+
min_count : Py_ssize_t
698+
Only used in add and prod. Always -1.
699+
is_datetimelike : bool
700+
True if `values` contains datetime-like entries.
701+
mask : ndarray[bool, ndim=2], optional
702+
Not used.
703+
result_mask : ndarray[bool, ndim=2], optional
704+
Not used.
705+
706+
Notes
707+
-----
708+
This method modifies the `out` parameter rather than returning an object.
709+
`counts` is modified to hold group sizes
710+
"""
711+
677712
cdef:
678713
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
679-
floating val, count, y, t
714+
floating val, count, y, t, nan_val
680715
floating[:, ::1] sumx, compensation
681716
int64_t[:, ::1] nobs
682717
Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -686,12 +721,13 @@ def group_mean(floating[:, ::1] out,
686721
if len_values != len_labels:
687722
raise ValueError("len(index) != len(labels)")
688723

689-
nobs = np.zeros((<object>out).shape, dtype=np.int64)
690724
# the below is equivalent to `np.zeros_like(out)` but faster
725+
nobs = np.zeros((<object>out).shape, dtype=np.int64)
691726
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
692727
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
693728

694729
N, K = (<object>values).shape
730+
nan_val = NPY_NAT if is_datetimelike else NAN
695731

696732
with nogil:
697733
for i in range(N):
@@ -703,7 +739,7 @@ def group_mean(floating[:, ::1] out,
703739
for j in range(K):
704740
val = values[i, j]
705741
# not nan
706-
if val == val:
742+
if val == val and not (is_datetimelike and val == NPY_NAT):
707743
nobs[lab, j] += 1
708744
y = val - compensation[lab, j]
709745
t = sumx[lab, j] + y
@@ -714,7 +750,7 @@ def group_mean(floating[:, ::1] out,
714750
for j in range(K):
715751
count = nobs[i, j]
716752
if nobs[i, j] == 0:
717-
out[i, j] = NAN
753+
out[i, j] = nan_val
718754
else:
719755
out[i, j] = sumx[i, j] / count
720756

pandas/core/groupby/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def _call_cython_op(
514514
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
515515
if self.kind == "aggregate":
516516
counts = np.zeros(ngroups, dtype=np.int64)
517-
if self.how in ["min", "max"]:
517+
if self.how in ["min", "max", "mean"]:
518518
func(
519519
result,
520520
counts,

pandas/tests/groupby/aggregate/test_aggregate.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MultiIndex,
2121
Series,
2222
concat,
23+
to_datetime,
2324
)
2425
import pandas._testing as tm
2526
from pandas.core.base import SpecificationError
@@ -66,7 +67,6 @@ def test_agg_ser_multi_key(df):
6667

6768

6869
def test_groupby_aggregation_mixed_dtype():
69-
7070
# GH 6212
7171
expected = DataFrame(
7272
{
@@ -1259,3 +1259,35 @@ def func(ser):
12591259

12601260
expected = DataFrame([[1.0]], index=[1])
12611261
tm.assert_frame_equal(res, expected)
1262+
1263+
1264+
def test_group_mean_timedelta_nat():
1265+
# GH43132
1266+
data = Series(["1 day", "3 days", "NaT"], dtype="timedelta64[ns]")
1267+
expected = Series(["2 days"], dtype="timedelta64[ns]")
1268+
1269+
result = data.groupby([0, 0, 0]).mean()
1270+
1271+
tm.assert_series_equal(result, expected)
1272+
1273+
1274+
@pytest.mark.parametrize(
1275+
"input_data, expected_output",
1276+
[
1277+
( # no timezone
1278+
["2021-01-01T00:00", "NaT", "2021-01-01T02:00"],
1279+
["2021-01-01T01:00"],
1280+
),
1281+
( # timezone
1282+
["2021-01-01T00:00-0100", "NaT", "2021-01-01T02:00-0100"],
1283+
["2021-01-01T01:00-0100"],
1284+
),
1285+
],
1286+
)
1287+
def test_group_mean_datetime64_nat(input_data, expected_output):
1288+
# GH43132
1289+
data = to_datetime(Series(input_data))
1290+
expected = to_datetime(Series(expected_output))
1291+
1292+
result = data.groupby([0, 0, 0]).mean()
1293+
tm.assert_series_equal(result, expected)

pandas/tests/groupby/test_libgroupby.py

+50
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
2+
import pytest
23

34
from pandas._libs import groupby as libgroupby
45
from pandas._libs.groupby import (
56
group_cumprod_float64,
67
group_cumsum,
8+
group_mean,
79
group_var,
810
)
911

@@ -234,3 +236,51 @@ def test_cython_group_transform_algos():
234236
]
235237
)
236238
tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected)
239+
240+
241+
def test_cython_group_mean_datetimelike():
242+
actual = np.zeros(shape=(1, 1), dtype="float64")
243+
counts = np.array([0], dtype="int64")
244+
data = (
245+
np.array(
246+
[np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")],
247+
dtype="m8[ns]",
248+
)[:, None]
249+
.view("int64")
250+
.astype("float64")
251+
)
252+
labels = np.zeros(len(data), dtype=np.intp)
253+
254+
group_mean(actual, counts, data, labels, is_datetimelike=True)
255+
256+
tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64"))
257+
258+
259+
def test_cython_group_mean_wrong_min_count():
260+
actual = np.zeros(shape=(1, 1), dtype="float64")
261+
counts = np.zeros(1, dtype="int64")
262+
data = np.zeros(1, dtype="float64")[:, None]
263+
labels = np.zeros(1, dtype=np.intp)
264+
265+
with pytest.raises(AssertionError, match="min_count"):
266+
group_mean(actual, counts, data, labels, is_datetimelike=True, min_count=0)
267+
268+
269+
def test_cython_group_mean_not_datetimelike_but_has_NaT_values():
270+
actual = np.zeros(shape=(1, 1), dtype="float64")
271+
counts = np.array([0], dtype="int64")
272+
data = (
273+
np.array(
274+
[np.timedelta64("NaT"), np.timedelta64("NaT")],
275+
dtype="m8[ns]",
276+
)[:, None]
277+
.view("int64")
278+
.astype("float64")
279+
)
280+
labels = np.zeros(len(data), dtype=np.intp)
281+
282+
group_mean(actual, counts, data, labels, is_datetimelike=False)
283+
284+
tm.assert_numpy_array_equal(
285+
actual[:, 0], np.array(np.divide(np.add(data[0], data[1]), 2), dtype="float64")
286+
)

0 commit comments

Comments
 (0)