Skip to content

Commit 7ccffcf

Browse files
committed
Make group_mean compatible with NaT
NaT is the datetime equivalent of NaN and is set to be the lowest possible 64-bit integer -(2**63). Previously, we could not support this value in any `groupby.mean()` calculations, which led to pandas-dev#43132. On a high level, we slightly modify the `group_mean` to not count NaT values. To do so, we introduce the `is_datetimelike` parameter to the function call (already present in other functions, e.g., `group_cumsum`) and refactor and extend `#_treat_as_na` to work with float64. ## Tests This PR adds an integration and two unit tests for the new functionality. In contrast to other tests in classes, I've tried to keep an individual test's scope as small as possible. Additionally, I've taken the liberty to: * Add a docstring for the group_mean algorithm. * Change the algorithm to use guard clauses instead of if/else. * Add a comment that we're using the Kahan summation (the compensation part initially confused me, and I only stumbled upon Kahan when browsing the file). - [x] closes pandas-dev#43132 - [x] tests added / passed - [x] Ensure all linting tests pass, see [here](https://pandas.pydata.org/pandas-docs/dev/development/contributing.html#code-standards) for how to run them - [x] whatsnew entry
1 parent 415dec5 commit 7ccffcf

File tree

6 files changed

+104
-15
lines changed

6 files changed

+104
-15
lines changed

doc/source/whatsnew/v1.3.3.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ Performance improvements
4646
Bug fixes
4747
~~~~~~~~~
4848
- Fixed bug in :meth:`.DataFrameGroupBy.agg` and :meth:`.DataFrameGroupBy.transform` with ``engine="numba"`` where ``index`` data was not being correctly passed into ``func`` (:issue:`43133`)
49-
49+
- :meth:`.GroupBy.mean` now supports ``NaT`` values (:issue:`43132`)
50+
-
5051
.. ---------------------------------------------------------------------------
5152
5253
.. _whatsnew_133.contributors:

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def group_mean(
7575
values: np.ndarray, # ndarray[floating, ndim=2]
7676
labels: np.ndarray, # const intp_t[:]
7777
min_count: int = ...,
78+
is_datetimelike: bool = ...,
7879
) -> None: ...
7980
def group_ohlc(
8081
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+46-14
Original file line numberDiff line numberDiff line change
@@ -675,10 +675,38 @@ def group_mean(floating[:, ::1] out,
675675
int64_t[::1] counts,
676676
ndarray[floating, ndim=2] values,
677677
const intp_t[::1] labels,
678-
Py_ssize_t min_count=-1) -> None:
678+
Py_ssize_t min_count=-1,
679+
bint is_datetimelike=False) -> None:
680+
"""
681+
Compute the mean per label given a label assignment for each value.
682+
NaN values are ignored.
683+
684+
Parameters
685+
----------
686+
out : np.ndarray[floating]
687+
Values into which this method will write its results.
688+
counts : np.ndarray[int64]
689+
A zeroed array of the same shape as labels,
690+
populated by group sizes during algorithm.
691+
values : np.ndarray[floating]
692+
2-d array of the values to find the mean of.
693+
labels : np.ndarray[np.intp]
694+
Array containing unique label for each group, with its
695+
ordering matching up to the corresponding record in `values`.
696+
is_datetimelike : bool
697+
True if `values` contains datetime-like entries.
698+
min_count : Py_ssize_t
699+
Only used in add and prod. Always -1.
700+
701+
Notes
702+
-----
703+
This method modifies the `out` parameter rather than returning an object.
704+
`counts` is modified to hold group sizes
705+
"""
706+
679707
cdef:
680708
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
681-
floating val, count, y, t
709+
floating val, count, y, t, nan_val = NPY_NAT if is_datetimelike else NAN
682710
floating[:, ::1] sumx, compensation
683711
int64_t[:, ::1] nobs
684712
Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -688,14 +716,15 @@ def group_mean(floating[:, ::1] out,
688716
if len_values != len_labels:
689717
raise ValueError("len(index) != len(labels)")
690718

691-
nobs = np.zeros((<object>out).shape, dtype=np.int64)
692719
# the below is equivalent to `np.zeros_like(out)` but faster
720+
nobs = np.zeros((<object>out).shape, dtype=np.int64)
693721
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
694722
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
695723

696724
N, K = (<object>values).shape
697725

698726
with nogil:
727+
# precompute count and sum
699728
for i in range(N):
700729
lab = labels[i]
701730
if lab < 0:
@@ -704,21 +733,24 @@ def group_mean(floating[:, ::1] out,
704733
counts[lab] += 1
705734
for j in range(K):
706735
val = values[i, j]
707-
# not nan
708-
if val == val:
709-
nobs[lab, j] += 1
710-
y = val - compensation[lab, j]
711-
t = sumx[lab, j] + y
712-
compensation[lab, j] = t - sumx[lab, j] - y
713-
sumx[lab, j] = t
714-
736+
if val != val or (is_datetimelike and val == NPY_NAT):
737+
continue
738+
# Use Kahan summation to reduce floating-point
739+
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm).
740+
nobs[lab, j] += 1
741+
y = val - compensation[lab, j]
742+
t = sumx[lab, j] + y
743+
compensation[lab, j] = t - sumx[lab, j] - y
744+
sumx[lab, j] = t
745+
746+
# fill output array
715747
for i in range(ncounts):
716748
for j in range(K):
717749
count = nobs[i, j]
718750
if nobs[i, j] == 0:
719-
out[i, j] = NAN
720-
else:
721-
out[i, j] = sumx[i, j] / count
751+
out[i, j] = nan_val
752+
continue
753+
out[i, j] = sumx[i, j] / count
722754

723755

724756
@cython.wraparound(False)

pandas/core/groupby/ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,15 @@ def _call_cython_op(
525525
result_mask=result_mask,
526526
is_datetimelike=is_datetimelike,
527527
)
528+
elif self.how == "mean":
529+
func(
530+
result,
531+
counts,
532+
values,
533+
comp_ids,
534+
min_count,
535+
is_datetimelike=is_datetimelike,
536+
)
528537
else:
529538
func(result, counts, values, comp_ids, min_count)
530539
else:

pandas/tests/groupby/test_libgroupby.py

+36
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas._libs.groupby import (
55
group_cumprod_float64,
66
group_cumsum,
7+
group_mean,
78
group_var,
89
)
910

@@ -234,3 +235,38 @@ def test_cython_group_transform_algos():
234235
]
235236
)
236237
tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected)
238+
239+
240+
def test_cython_group_mean_datetimelike():
241+
actual = np.zeros(shape=(1, 1), dtype="float64")
242+
counts = np.array([0], dtype="int64")
243+
data = (
244+
np.array(
245+
[np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")],
246+
dtype="m8[ns]",
247+
)[:, None]
248+
.view("int64")
249+
.astype("float64")
250+
)
251+
labels = np.zeros(len(data), dtype=np.intp)
252+
253+
group_mean(actual, counts, data, labels, is_datetimelike=True)
254+
255+
tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64"))
256+
257+
258+
def test_cython_group_mean_not_datetimelike():
259+
actual = np.zeros(shape=(3, 1), dtype="float64")
260+
counts = np.zeros(2, dtype="int64")
261+
data = np.array(
262+
[
263+
[np.float64(1), np.float64(2), np.float64(3)],
264+
[np.nan, np.float64(2), np.float64(3)],
265+
],
266+
dtype="float64",
267+
)
268+
labels = np.array([0, 0], dtype=np.intp)
269+
270+
group_mean(actual, counts, data, labels, is_datetimelike=False)
271+
272+
tm.assert_numpy_array_equal(actual[:, 0], np.array([1, 2, 3], dtype="float64"))

pandas/tests/groupby/transform/test_transform.py

+10
Original file line numberDiff line numberDiff line change
@@ -1289,3 +1289,13 @@ def test_transform_cumcount():
12891289

12901290
result = grp.transform("cumcount")
12911291
tm.assert_series_equal(result, expected)
1292+
1293+
1294+
def test_group_mean_timedelta_nat():
1295+
series = Series(["1 day", "3 days", pd.NaT], dtype="timedelta64[ns]")
1296+
1297+
group = series.groupby([1, 1, 1])
1298+
result = group.transform("mean")
1299+
1300+
expected = Series(["2 days", "2 days", "2 days"], dtype="timedelta64[ns]")
1301+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)