Skip to content

Commit 3f271b2

Browse files
committed
Make group_mean compatible with NaT
NaT is the datetime equivalent of NaN and is set to be the lowest possible 64 bit integer -(2**63). Previously, we could not support this value in any groupby.mean() calculations, which led to #43132. At a high level, we slightly modify `group_mean` to not count NaT values. To do so, we introduce the `is_datetimelike` parameter to the function call (already present in other functions, e.g., `group_cumsum`) and refactor and extend `_treat_as_na` to work with float64. This PR adds an additional integration test and a unit test for the new functionality. In contrast to other tests grouped in classes, I've tried to keep an individual test's scope as small as possible. Additionally, I've taken the liberty to: * Add a docstring for the group_mean algorithm. * Change the algorithm to use guard clauses instead of else/if. * Add a comment that we're using Kahan summation (the compensation part initially confused me, and I only stumbled upon Kahan when browsing the file). - [x] closes #43132 - [x] tests added / passed - [x] Ensure all linting tests pass, see [here](https://pandas.pydata.org/pandas-docs/dev/development/contributing.html#code-standards) for how to run them - [x] whatsnew entry => different format but it's there
1 parent b17379b commit 3f271b2

File tree

5 files changed

+77
-17
lines changed

5 files changed

+77
-17
lines changed

pandas/_libs/groupby.pyi

+1
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ def group_mean(
7575
values: np.ndarray, # ndarray[floating, ndim=2]
7676
labels: np.ndarray, # const intp_t[:]
7777
min_count: int = ...,
78+
is_datetimelike: bool = ...,
7879
) -> None: ...
7980
def group_ohlc(
8081
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+45-16
Original file line number | Diff line number | Diff line change
@@ -675,10 +675,36 @@ def group_mean(floating[:, ::1] out,
675675
int64_t[::1] counts,
676676
ndarray[floating, ndim=2] values,
677677
const intp_t[::1] labels,
678-
Py_ssize_t min_count=-1) -> None:
678+
Py_ssize_t min_count=-1,
679+
bint is_datetimelike = False) -> None:
680+
"""
681+
Compute the mean per label given a label assignment for each value. NaN values are ignored.
682+
683+
Parameters
684+
----------
685+
out : np.ndarray[floating]
686+
Values into which this method will write its results.
687+
counts : np.ndarray[int64]
688+
A zeroed array of the same shape as labels, populated by group sizes during algorithm.
689+
values : np.ndarray[floating]
690+
2-d array of the values to find the mean of.
691+
labels : np.ndarray[np.intp]
692+
Array containing unique label for each group, with its
693+
ordering matching up to the corresponding record in `values`.
694+
is_datetimelike : bool
695+
True if `values` contains datetime-like entries.
696+
min_count : Py_ssize_t
697+
Only used in add and prod. Always -1.
698+
699+
Notes
700+
-----
701+
This method modifies the `out` parameter rather than returning an object.
702+
`counts` is modified to hold group sizes
703+
"""
704+
679705
cdef:
680706
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
681-
floating val, count, y, t
707+
floating val, count, y, t, nan_val = NPY_NAT if is_datetimelike else NAN
682708
floating[:, ::1] sumx, compensation
683709
int64_t[:, ::1] nobs
684710
Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -688,14 +714,15 @@ def group_mean(floating[:, ::1] out,
688714
if len_values != len_labels:
689715
raise ValueError("len(index) != len(labels)")
690716

691-
nobs = np.zeros((<object>out).shape, dtype=np.int64)
692717
# the below is equivalent to `np.zeros_like(out)` but faster
718+
nobs = np.zeros((<object>out).shape, dtype=np.int64)
693719
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
694720
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
695721

696722
N, K = (<object>values).shape
697723

698724
with nogil:
725+
# precompute count and sum
699726
for i in range(N):
700727
lab = labels[i]
701728
if lab < 0:
@@ -704,22 +731,24 @@ def group_mean(floating[:, ::1] out,
704731
counts[lab] += 1
705732
for j in range(K):
706733
val = values[i, j]
707-
# not nan
708-
if val == val:
709-
nobs[lab, j] += 1
710-
y = val - compensation[lab, j]
711-
t = sumx[lab, j] + y
712-
compensation[lab, j] = t - sumx[lab, j] - y
713-
sumx[lab, j] = t
714-
734+
if _treat_as_na(val, is_datetimelike):
735+
continue
736+
# Use Kahan summation to reduce floating-point
737+
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm).
738+
nobs[lab, j] += 1
739+
y = val - compensation[lab, j]
740+
t = sumx[lab, j] + y
741+
compensation[lab, j] = t - sumx[lab, j] - y
742+
sumx[lab, j] = t
743+
744+
# fill output array
715745
for i in range(ncounts):
716746
for j in range(K):
717747
count = nobs[i, j]
718748
if nobs[i, j] == 0:
719-
out[i, j] = NAN
720-
else:
721-
out[i, j] = sumx[i, j] / count
722-
749+
out[i, j] = nan_val
750+
continue
751+
out[i, j] = sumx[i, j] / count
723752

724753
@cython.wraparound(False)
725754
@cython.boundscheck(False)
@@ -900,7 +929,7 @@ cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
900929
# or else cython will raise about gil acquisition.
901930
raise NotImplementedError
902931

903-
elif rank_t is int64_t:
932+
elif rank_t is int64_t or rank_t is float64_t:
904933
return is_datetimelike and val == NPY_NAT
905934
elif rank_t is uint64_t:
906935
# There is no NA value for uint64

pandas/core/groupby/ops.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -515,7 +515,7 @@ def _call_cython_op(
515515
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
516516
if self.kind == "aggregate":
517517
counts = np.zeros(ngroups, dtype=np.int64)
518-
if self.how in ["min", "max"]:
518+
if self.how in {"min", "max", "mean"}:
519519
func(
520520
result,
521521
counts,

pandas/tests/groupby/test_libgroupby.py

+20
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
group_cumprod_float64,
66
group_cumsum,
77
group_var,
8+
group_mean,
89
)
910

1011
from pandas.core.dtypes.common import ensure_platform_int
@@ -234,3 +235,22 @@ def test_cython_group_transform_algos():
234235
]
235236
)
236237
tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected)
238+
239+
240+
def test_cython_group_mean_timedelta():
241+
is_datetimelike = True
242+
actual = np.zeros(shape=(1, 1), dtype="float64")
243+
counts = np.array([0], dtype="int64")
244+
data = (
245+
np.array(
246+
[np.datetime64(2, "ns"), np.datetime64(4, "ns"), np.datetime64("NaT")],
247+
dtype="m8[ns]",
248+
)[:, None]
249+
.view("int64")
250+
.astype("float64")
251+
)
252+
labels = np.zeros(len(data), dtype="int64")
253+
254+
group_mean(actual, counts, data, labels, is_datetimelike=True)
255+
256+
tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64"))

pandas/tests/groupby/transform/test_transform.py

+10
Original file line number | Diff line number | Diff line change
@@ -1289,3 +1289,13 @@ def test_transform_cumcount():
12891289

12901290
result = grp.transform("cumcount")
12911291
tm.assert_series_equal(result, expected)
1292+
1293+
1294+
def test_group_mean_timedelta_nat():
1295+
series = Series(["1 day", "3 days", pd.NaT], dtype="timedelta64[ns]")
1296+
1297+
group = series.groupby([1, 1, 1])
1298+
result = group.transform("mean")
1299+
1300+
expected = Series(["2 days", "2 days", "2 days"], dtype="timedelta64[ns]")
1301+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)