Skip to content

Commit 128fc9a

Browse files
authored
ENH: support td64/dt64 in GroupBy.std (#51333)
* ENH: support td64/dt64 in GroupBy.std * troubleshoot 32bit builds * fix test * troubleshoot 32bit builds * mypy fixup * troubleshoot CI
1 parent 70515df commit 128fc9a

File tree

7 files changed

+67
-9
lines changed

7 files changed

+67
-9
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ Other enhancements
284284
- Added support for ``dt`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.timestamp`` type (:issue:`50954`)
285285
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
286286
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
287+
- :meth:`.DataFrameGroupBy.std`, :meth:`.SeriesGroupBy.std` now support datetime64, timedelta64, and :class:`DatetimeTZDtype` dtypes (:issue:`48481`)
287288
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
288289
- :func:`.testing.assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
289290
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def group_var(
8585
ddof: int = ..., # int64_t
8686
mask: np.ndarray | None = ...,
8787
result_mask: np.ndarray | None = ...,
88+
is_datetimelike: bool = ...,
8889
) -> None: ...
8990
def group_mean(
9091
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+7-1
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def group_var(
818818
int64_t ddof=1,
819819
const uint8_t[:, ::1] mask=None,
820820
uint8_t[:, ::1] result_mask=None,
821+
bint is_datetimelike=False,
821822
) -> None:
822823
cdef:
823824
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
@@ -852,8 +853,13 @@ def group_var(
852853

853854
if uses_mask:
854855
isna_entry = mask[i, j]
856+
elif is_datetimelike:
857+
# With group_var, we cannot just use _treat_as_na bc
858+
# datetimelike dtypes get cast to float64 instead of
859+
# to int64.
860+
isna_entry = val == NPY_NAT
855861
else:
856-
isna_entry = _treat_as_na(val, False)
862+
isna_entry = _treat_as_na(val, is_datetimelike)
857863

858864
if not isna_entry:
859865
nobs[lab, j] += 1

pandas/core/groupby/groupby.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class providing the base-class of operations.
3030
cast,
3131
final,
3232
)
33+
import warnings
3334

3435
import numpy as np
3536

@@ -97,8 +98,10 @@ class providing the base-class of operations.
9798
BaseMaskedArray,
9899
BooleanArray,
99100
Categorical,
101+
DatetimeArray,
100102
ExtensionArray,
101103
FloatingArray,
104+
TimedeltaArray,
102105
)
103106
from pandas.core.base import (
104107
PandasObject,
@@ -3724,7 +3727,10 @@ def blk_func(values: ArrayLike) -> ArrayLike:
37243727
counts = np.zeros(ngroups, dtype=np.int64)
37253728
func = partial(func, counts=counts)
37263729

3730+
is_datetimelike = values.dtype.kind in ["m", "M"]
37273731
vals = values
3732+
if is_datetimelike and how == "std":
3733+
vals = vals.view("i8")
37283734
if pre_processing:
37293735
vals, inferences = pre_processing(vals)
37303736

@@ -3747,7 +3753,11 @@ def blk_func(values: ArrayLike) -> ArrayLike:
37473753
result_mask = np.zeros(result.shape, dtype=np.bool_)
37483754
func = partial(func, result_mask=result_mask)
37493755

3750-
func(**kwargs) # Call func to modify result in place
3756+
# Call func to modify result in place
3757+
if how == "std":
3758+
func(**kwargs, is_datetimelike=is_datetimelike)
3759+
else:
3760+
func(**kwargs)
37513761

37523762
if values.ndim == 1:
37533763
assert result.shape[1] == 1, result.shape
@@ -3761,6 +3771,15 @@ def blk_func(values: ArrayLike) -> ArrayLike:
37613771

37623772
result = post_processing(result, inferences, **pp_kwargs)
37633773

3774+
if how == "std" and is_datetimelike:
3775+
values = cast("DatetimeArray | TimedeltaArray", values)
3776+
unit = values.unit
3777+
with warnings.catch_warnings():
3778+
# suppress "RuntimeWarning: invalid value encountered in cast"
3779+
warnings.filterwarnings("ignore")
3780+
result = result.astype(np.int64, copy=False)
3781+
result = result.view(f"m8[{unit}]")
3782+
37643783
return result.T
37653784

37663785
# Operate block-wise instead of column-by-column

pandas/tests/groupby/test_groupby.py

+27
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,33 @@ def test_repr():
3737
assert result == expected
3838

3939

40+
def test_groupby_std_datetimelike():
41+
# GH#48481
42+
tdi = pd.timedelta_range("1 Day", periods=10000)
43+
ser = Series(tdi)
44+
ser[::5] *= 2 # get different std for different groups
45+
46+
df = ser.to_frame("A")
47+
48+
df["B"] = ser + Timestamp(0)
49+
df["C"] = ser + Timestamp(0, tz="UTC")
50+
df.iloc[-1] = pd.NaT # last group includes NaTs
51+
52+
gb = df.groupby(list(range(5)) * 2000)
53+
54+
result = gb.std()
55+
56+
# Note: this does not _exactly_ match what we would get if we did
57+
# [gb.get_group(i).std() for i in gb.groups]
58+
# but it _does_ match the floating point error we get doing the
59+
# same operation on int64 data xref GH#51332
60+
td1 = Timedelta("2887 days 11:21:02.326710176")
61+
td4 = Timedelta("2886 days 00:42:34.664668096")
62+
exp_ser = Series([td1 * 2, td1, td1, td1, td4], index=np.arange(5))
63+
expected = DataFrame({"A": exp_ser, "B": exp_ser, "C": exp_ser})
64+
tm.assert_frame_equal(result, expected)
65+
66+
4067
@pytest.mark.parametrize("dtype", ["int64", "int32", "float64", "float32"])
4168
def test_basic(dtype):
4269

pandas/tests/groupby/test_raises.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,11 @@ def test_groupby_raises_datetime(how, by, groupby_series, groupby_func):
224224
"prod": (TypeError, "datetime64 type does not support prod"),
225225
"quantile": (None, ""),
226226
"rank": (None, ""),
227-
"sem": (TypeError, "Cannot cast DatetimeArray to dtype float64"),
227+
"sem": (None, ""),
228228
"shift": (None, ""),
229229
"size": (None, ""),
230230
"skew": (TypeError, r"dtype datetime64\[ns\] does not support reduction"),
231-
"std": (TypeError, "Cannot cast DatetimeArray to dtype float64"),
231+
"std": (None, ""),
232232
"sum": (TypeError, "datetime64 type does not support sum operations"),
233233
"var": (None, ""),
234234
}[groupby_func]

pandas/tests/resample/test_resample_api.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -405,12 +405,16 @@ def test_agg():
405405
expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1)
406406
expected.columns = pd.MultiIndex.from_product([["A", "B"], ["mean", "std"]])
407407
for t in cases:
408-
# In case 2, "date" is an index and a column, so agg still tries to agg
408+
# In case 2, "date" is an index and a column, so get included in the agg
409409
if t == cases[2]:
410-
# .var on dt64 column raises
411-
msg = "Cannot cast DatetimeArray to dtype float64"
412-
with pytest.raises(TypeError, match=msg):
413-
t.aggregate([np.mean, np.std])
410+
date_mean = t["date"].mean()
411+
date_std = t["date"].std()
412+
exp = pd.concat([date_mean, date_std, expected], axis=1)
413+
exp.columns = pd.MultiIndex.from_product(
414+
[["date", "A", "B"], ["mean", "std"]]
415+
)
416+
result = t.aggregate([np.mean, np.std])
417+
tm.assert_frame_equal(result, exp)
414418
else:
415419
result = t.aggregate([np.mean, np.std])
416420
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)