Skip to content

Commit e1cadfa

Browse files
jbrockmendel authored and jreback committed
REF: make _aggregate_series_pure_python extraction behave like the cython version (#29641)
1 parent b9b462c commit e1cadfa

File tree

3 files changed

+43
-21
lines changed

3 files changed

+43
-21
lines changed

pandas/core/groupby/groupby.py

+7-18
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ class providing the base-class of operations.
3131
from pandas.core.dtypes.common import (
3232
ensure_float,
3333
is_datetime64_dtype,
34-
is_datetime64tz_dtype,
3534
is_extension_array_dtype,
3635
is_integer_dtype,
3736
is_numeric_dtype,
@@ -45,7 +44,6 @@ class providing the base-class of operations.
4544
from pandas.core.arrays import Categorical, try_cast_to_ea
4645
from pandas.core.base import DataError, PandasObject, SelectionMixin
4746
import pandas.core.common as com
48-
from pandas.core.construction import extract_array
4947
from pandas.core.frame import DataFrame
5048
from pandas.core.generic import NDFrame
5149
from pandas.core.groupby import base, ops
@@ -790,22 +788,11 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
790788
dtype = obj.dtype
791789

792790
if not is_scalar(result):
793-
if is_datetime64tz_dtype(dtype):
794-
# GH 23683
795-
# Prior results _may_ have been generated in UTC.
796-
# Ensure we localize to UTC first before converting
797-
# to the target timezone
798-
arr = extract_array(obj)
799-
try:
800-
result = arr._from_sequence(result, dtype="datetime64[ns, UTC]")
801-
result = result.astype(dtype)
802-
except TypeError:
803-
# _try_cast was called at a point where the result
804-
# was already tz-aware
805-
pass
806-
elif is_extension_array_dtype(dtype):
791+
if is_extension_array_dtype(dtype) and dtype.kind != "M":
807792
# The function can return something of any type, so check
808-
# if the type is compatible with the calling EA.
793+
# if the type is compatible with the calling EA.
794+
# datetime64tz is handled correctly in agg_series,
795+
# so is excluded here.
809796

810797
# return the same type (Series) as our caller
811798
cls = dtype.construct_array_type()
@@ -872,7 +859,9 @@ def _cython_agg_general(
872859
if numeric_only and not is_numeric:
873860
continue
874861

875-
result, names = self.grouper.aggregate(obj.values, how, min_count=min_count)
862+
result, names = self.grouper.aggregate(
863+
obj._values, how, min_count=min_count
864+
)
876865
output[name] = self._try_cast(result, obj)
877866

878867
if len(output) == 0:

pandas/core/groupby/ops.py

+11-3
Original file line number | Diff line number | Diff line change
@@ -604,11 +604,11 @@ def agg_series(self, obj: Series, func):
604604
# SeriesGrouper would raise if we were to call _aggregate_series_fast
605605
return self._aggregate_series_pure_python(obj, func)
606606

607-
elif is_extension_array_dtype(obj.dtype) and obj.dtype.kind != "M":
607+
elif is_extension_array_dtype(obj.dtype):
608608
# _aggregate_series_fast would raise TypeError when
609609
# calling libreduction.Slider
610+
# In the datetime64tz case it would incorrectly cast to tz-naive
610611
# TODO: can we get a performant workaround for EAs backed by ndarray?
611-
# TODO: is the datetime64tz case supposed to go through here?
612612
return self._aggregate_series_pure_python(obj, func)
613613

614614
elif isinstance(obj.index, MultiIndex):
@@ -657,7 +657,15 @@ def _aggregate_series_pure_python(self, obj: Series, func):
657657
res = func(group)
658658
if result is None:
659659
if isinstance(res, (Series, Index, np.ndarray)):
660-
raise ValueError("Function does not reduce")
660+
if len(res) == 1:
661+
# e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
662+
# FIXME: are we potentially losing important res.index info?
663+
664+
# TODO: use `.item()` if/when we un-deprecate it.
665+
# For non-Series we could just do `res[0]`
666+
res = next(iter(res))
667+
else:
668+
raise ValueError("Function does not reduce")
661669
result = np.empty(ngroups, dtype="O")
662670

663671
counts[label] = group.shape[0]

pandas/tests/groupby/aggregate/test_other.py

+25
Original file line number | Diff line number | Diff line change
@@ -454,6 +454,31 @@ def test_agg_over_numpy_arrays():
454454
tm.assert_frame_equal(result, expected)
455455

456456

457+
def test_agg_tzaware_non_datetime_result():
458+
# discussed in GH#29589, fixed in GH#29641, operating on tzaware values
459+
# with function that is not dtype-preserving
460+
dti = pd.date_range("2012-01-01", periods=4, tz="UTC")
461+
df = pd.DataFrame({"a": [0, 0, 1, 1], "b": dti})
462+
gb = df.groupby("a")
463+
464+
# Case that _does_ preserve the dtype
465+
result = gb["b"].agg(lambda x: x.iloc[0])
466+
expected = pd.Series(dti[::2], name="b")
467+
expected.index.name = "a"
468+
tm.assert_series_equal(result, expected)
469+
470+
# Cases that do _not_ preserve the dtype
471+
result = gb["b"].agg(lambda x: x.iloc[0].year)
472+
expected = pd.Series([2012, 2012], name="b")
473+
expected.index.name = "a"
474+
tm.assert_series_equal(result, expected)
475+
476+
result = gb["b"].agg(lambda x: x.iloc[-1] - x.iloc[0])
477+
expected = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b")
478+
expected.index.name = "a"
479+
tm.assert_series_equal(result, expected)
480+
481+
457482
def test_agg_timezone_round_trip():
458483
# GH 15426
459484
ts = pd.Timestamp("2016-01-01 12:00:00", tz="US/Pacific")

0 commit comments

Comments
 (0)