
Commit 5a23e0f

jbrockmendel and AlexeyGy authored and committed
REF: Groupby._get_cythonized_result operate blockwise in axis==1 case (#43435)
1 parent e7efcca commit 5a23e0f

File tree

6 files changed (+130, -10 lines)


doc/source/whatsnew/v1.3.4.rst (+1, -2)

@@ -23,8 +23,7 @@ Fixed regressions

 Bug fixes
 ~~~~~~~~~
--
--
+- Fixed bug in :meth:`.GroupBy.mean` with datetimelike values including ``NaT`` values returning incorrect results (:issue:`43132`)

 .. ---------------------------------------------------------------------------

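A minimal user-level sketch of the behavior this whatsnew entry describes (plain Python, not part of the patch): with the fix, ``NaT`` entries are skipped when computing a group mean over datetime-like values instead of corrupting the result.

import pandas as pd

# Group mean over timedeltas containing NaT: the NaT entry is ignored,
# so the single group's mean is 2 days.
s = pd.Series(pd.to_timedelta(["1 day", "3 days", "NaT"]))
print(s.groupby([0, 0, 0]).mean())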

pandas/_libs/groupby.pyi (+4, -1)

@@ -77,7 +77,10 @@ def group_mean(
     counts: np.ndarray,  # int64_t[::1]
     values: np.ndarray,  # ndarray[floating, ndim=2]
     labels: np.ndarray,  # const intp_t[:]
-    min_count: int = ...,
+    min_count: int = ...,  # Py_ssize_t
+    is_datetimelike: bool = ...,  # bint
+    mask: np.ndarray | None = ...,
+    result_mask: np.ndarray | None = ...,
 ) -> None: ...
 def group_ohlc(
     out: np.ndarray,  # floating[:, ::1]

pandas/_libs/groupby.pyx (+41, -5)

@@ -673,10 +673,45 @@ def group_mean(floating[:, ::1] out,
                int64_t[::1] counts,
                ndarray[floating, ndim=2] values,
                const intp_t[::1] labels,
-               Py_ssize_t min_count=-1) -> None:
+               Py_ssize_t min_count=-1,
+               bint is_datetimelike=False,
+               const uint8_t[:, ::1] mask=None,
+               uint8_t[:, ::1] result_mask=None
+               ) -> None:
+    """
+    Compute the mean per label given a label assignment for each value.
+    NaN values are ignored.
+
+    Parameters
+    ----------
+    out : np.ndarray[floating]
+        Values into which this method will write its results.
+    counts : np.ndarray[int64]
+        A zeroed array of the same shape as labels,
+        populated by group sizes during algorithm.
+    values : np.ndarray[floating]
+        2-d array of the values to find the mean of.
+    labels : np.ndarray[np.intp]
+        Array containing unique label for each group, with its
+        ordering matching up to the corresponding record in `values`.
+    min_count : Py_ssize_t
+        Only used in add and prod. Always -1.
+    is_datetimelike : bool
+        True if `values` contains datetime-like entries.
+    mask : ndarray[bool, ndim=2], optional
+        Not used.
+    result_mask : ndarray[bool, ndim=2], optional
+        Not used.
+
+    Notes
+    -----
+    This method modifies the `out` parameter rather than returning an object.
+    `counts` is modified to hold group sizes
+    """
+
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
-        floating val, count, y, t
+        floating val, count, y, t, nan_val
         floating[:, ::1] sumx, compensation
         int64_t[:, ::1] nobs
         Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -686,12 +721,13 @@ def group_mean(floating[:, ::1] out,
     if len_values != len_labels:
         raise ValueError("len(index) != len(labels)")

-    nobs = np.zeros((<object>out).shape, dtype=np.int64)
     # the below is equivalent to `np.zeros_like(out)` but faster
+    nobs = np.zeros((<object>out).shape, dtype=np.int64)
     sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
     compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)

     N, K = (<object>values).shape
+    nan_val = NPY_NAT if is_datetimelike else NAN

     with nogil:
         for i in range(N):
@@ -703,7 +739,7 @@
             for j in range(K):
                 val = values[i, j]
                 # not nan
-                if val == val:
+                if val == val and not (is_datetimelike and val == NPY_NAT):
                     nobs[lab, j] += 1
                     y = val - compensation[lab, j]
                     t = sumx[lab, j] + y
@@ -714,7 +750,7 @@
             for j in range(K):
                 count = nobs[i, j]
                 if nobs[i, j] == 0:
-                    out[i, j] = NAN
+                    out[i, j] = nan_val
                 else:
                     out[i, j] = sumx[i, j] / count
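For readers less familiar with Cython, here is a minimal pure-Python sketch of what the updated kernel computes; the names `group_mean_reference`, `ngroups`, and `comp` are illustrative, not part of the patch. It mirrors the Kahan-compensated per-group sums and the two missing-value rules: plain NaN is always skipped, and the NPY_NAT sentinel is additionally skipped (and used as the fill value for empty groups) when `is_datetimelike` is true.

import numpy as np

iNaT = np.iinfo(np.int64).min  # the NPY_NAT sentinel the kernel compares against

def group_mean_reference(values, labels, ngroups, is_datetimelike=False):
    """Illustrative pure-Python analogue of the patched group_mean kernel."""
    N, K = values.shape
    nobs = np.zeros((ngroups, K), dtype=np.int64)   # per-(group, column) counts
    sumx = np.zeros((ngroups, K))                   # running Kahan sums
    comp = np.zeros((ngroups, K))                   # Kahan compensation terms
    out = np.empty((ngroups, K))
    nan_val = iNaT if is_datetimelike else np.nan   # fill value for empty groups

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue
        for j in range(K):
            val = values[i, j]
            # skip NaN, and skip the NaT sentinel for datetime-like input
            if val == val and not (is_datetimelike and val == iNaT):
                nobs[lab, j] += 1
                y = val - comp[lab, j]
                t = sumx[lab, j] + y
                comp[lab, j] = t - sumx[lab, j] - y
                sumx[lab, j] = t

    for i in range(ngroups):
        for j in range(K):
            out[i, j] = nan_val if nobs[i, j] == 0 else sumx[i, j] / nobs[i, j]
    return out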

pandas/core/groupby/ops.py (+1, -1)

@@ -514,7 +514,7 @@ def _call_cython_op(
        result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
        if self.kind == "aggregate":
            counts = np.zeros(ngroups, dtype=np.int64)
-            if self.how in ["min", "max"]:
+            if self.how in ["min", "max", "mean"]:
                func(
                    result,
                    counts,
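The hunk cuts off just after the call opens; below is a rough sketch of the dispatch this one-line change extends. The argument names (e.g. `comp_ids`) and the exact keyword list are assumptions based on the visible context, not shown by the diff: kernels named in the list receive the `is_datetimelike` flag that the new `group_mean` signature accepts, while other aggregation kernels keep the shorter call.

# Hypothetical sketch of the branch this change extends (names assumed):
if self.how in ["min", "max", "mean"]:
    # these kernels accept (and mean now needs) the is_datetimelike flag
    func(
        result,
        counts,
        values,
        comp_ids,              # assumed name for the group-label array
        min_count,
        is_datetimelike=is_datetimelike,
    )
else:
    func(result, counts, values, comp_ids, min_count)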

pandas/tests/groupby/aggregate/test_aggregate.py (+33, -1)

@@ -66,7 +66,6 @@ def test_agg_ser_multi_key(df):


 def test_groupby_aggregation_mixed_dtype():
-
     # GH 6212
     expected = DataFrame(
         {
@@ -1274,3 +1273,36 @@ def func(ser):

     expected = DataFrame([[1.0]], index=[1])
     tm.assert_frame_equal(res, expected)
+
+
+@pytest.mark.parametrize(
+    "input_data, expected_output",
+    [
+        (  # timedelta
+            {"dtype": "timedelta64[ns]", "values": ["1 day", "3 days", "NaT"]},
+            {"dtype": "timedelta64[ns]", "values": ["2 days"]},
+        ),
+        (  # datetime
+            {
+                "dtype": "datetime64[ns]",
+                "values": ["2021-01-01T00:00", "NaT", "2021-01-01T02:00"],
+            },
+            {"dtype": "datetime64[ns]", "values": ["2021-01-01T01:00"]},
+        ),
+        (  # timezoned data
+            {
+                "dtype": "datetime64[ns]",
+                "values": ["2021-01-01T00:00-0100", "NaT", "2021-01-01T02:00-0100"],
+            },
+            {"dtype": "datetime64[ns]", "values": ["2021-01-01T01:00"]},
+        ),
+    ],
+)
+def test_group_mean_timedelta_nat(input_data, expected_output):
+    data = Series(input_data["values"], dtype=input_data["dtype"])
+
+    actual = data.groupby([0, 0, 0]).mean()
+
+    expected = Series(expected_output["values"], dtype=expected_output["dtype"])
+
+    tm.assert_series_equal(actual, expected)

pandas/tests/groupby/test_libgroupby.py (+50)

@@ -1,9 +1,11 @@
 import numpy as np
+import pytest

 from pandas._libs import groupby as libgroupby
 from pandas._libs.groupby import (
     group_cumprod_float64,
     group_cumsum,
+    group_mean,
     group_var,
 )

@@ -234,3 +236,51 @@ def test_cython_group_transform_algos():
         ]
     )
     tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected)
+
+
+def test_cython_group_mean_datetimelike():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.array([0], dtype="int64")
+    data = (
+        np.array(
+            [np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")],
+            dtype="m8[ns]",
+        )[:, None]
+        .view("int64")
+        .astype("float64")
+    )
+    labels = np.zeros(len(data), dtype=np.intp)
+
+    group_mean(actual, counts, data, labels, is_datetimelike=True)
+
+    tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64"))
+
+
+def test_cython_group_mean_wrong_min_count():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.zeros(1, dtype="int64")
+    data = np.zeros(1, dtype="float64")[:, None]
+    labels = np.zeros(1, dtype=np.intp)
+
+    with pytest.raises(AssertionError, match="min_count"):
+        group_mean(actual, counts, data, labels, is_datetimelike=True, min_count=0)
+
+
+def test_cython_group_mean_not_datetimelike_but_has_NaT_values():
+    actual = np.zeros(shape=(1, 1), dtype="float64")
+    counts = np.array([0], dtype="int64")
+    data = (
+        np.array(
+            [np.timedelta64("NaT"), np.timedelta64("NaT")],
+            dtype="m8[ns]",
+        )[:, None]
+        .view("int64")
+        .astype("float64")
+    )
+    labels = np.zeros(len(data), dtype=np.intp)
+
+    group_mean(actual, counts, data, labels, is_datetimelike=False)
+
+    tm.assert_numpy_array_equal(
+        actual[:, 0], np.array(np.divide(np.add(data[0], data[1]), 2), dtype="float64")
+    )
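A small standalone sketch of how these tests encode datetime-like data for the kernel (values viewed as int64, then cast to float64) and why the explicit `is_datetimelike` check is needed: iNaT is the minimum int64 value, which becomes an ordinary finite float after the cast, so the kernel's `val != val` NaN test does not catch it.

import numpy as np

# m8[ns] -> int64 -> float64, the same preparation the tests use.
arr = np.array([np.timedelta64(2, "ns"), np.timedelta64("NaT")], dtype="m8[ns]")
as_float = arr.view("int64").astype("float64")

print(as_float)                               # 2.0 and roughly -9.22e18 (iNaT as a float)
print(np.isnan(as_float))                     # [False, False] -> NaT is not NaN here
print(as_float[1] == np.iinfo(np.int64).min)  # True -> kernel checks the sentinel instead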
