Skip to content

REF: make _aggregate_series_pure_python extraction behave like the cython version #29641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 18, 2019
25 changes: 7 additions & 18 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class providing the base-class of operations.
from pandas.core.dtypes.common import (
ensure_float,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_integer_dtype,
is_numeric_dtype,
Expand All @@ -45,7 +44,6 @@ class providing the base-class of operations.
from pandas.core.arrays import Categorical, try_cast_to_ea
from pandas.core.base import DataError, PandasObject, SelectionMixin
import pandas.core.common as com
from pandas.core.construction import extract_array
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base
Expand Down Expand Up @@ -789,22 +787,11 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
dtype = obj.dtype

if not is_scalar(result):
if is_datetime64tz_dtype(dtype):
# GH 23683
# Prior results _may_ have been generated in UTC.
# Ensure we localize to UTC first before converting
# to the target timezone
arr = extract_array(obj)
try:
result = arr._from_sequence(result, dtype="datetime64[ns, UTC]")
result = result.astype(dtype)
except TypeError:
# _try_cast was called at a point where the result
# was already tz-aware
pass
elif is_extension_array_dtype(dtype):
if is_extension_array_dtype(dtype) and dtype.kind != "M":
# The function can return something of any type, so check
# if the type is compatible with the calling EA.
# if the type is compatible with the calling EA.
# datetime64tz is handled correctly in agg_series,
# so is excluded here.

# return the same type (Series) as our caller
cls = dtype.construct_array_type()
Expand Down Expand Up @@ -869,7 +856,9 @@ def _cython_agg_general(
if numeric_only and not is_numeric:
continue

result, names = self.grouper.aggregate(obj.values, how, min_count=min_count)
result, names = self.grouper.aggregate(
obj._values, how, min_count=min_count
)
output[name] = self._try_cast(result, obj)

if len(output) == 0:
Expand Down
12 changes: 9 additions & 3 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,11 @@ def agg_series(self, obj: Series, func):
# SeriesGrouper would raise if we were to call _aggregate_series_fast
return self._aggregate_series_pure_python(obj, func)

elif is_extension_array_dtype(obj.dtype) and obj.dtype.kind != "M":
elif is_extension_array_dtype(obj.dtype):
# _aggregate_series_fast would raise TypeError when
# calling libreduction.Slider
# In the datetime64tz case it would incorrectly cast to tz-naive
# TODO: can we get a performant workaround for EAs backed by ndarray?
# TODO: is the datetime64tz case supposed to go through here?
return self._aggregate_series_pure_python(obj, func)

elif isinstance(obj.index, MultiIndex):
Expand Down Expand Up @@ -657,7 +657,13 @@ def _aggregate_series_pure_python(self, obj: Series, func):
res = func(group)
if result is None:
if isinstance(res, (Series, Index, np.ndarray)):
raise ValueError("Function does not reduce")
if len(res) == 1:
# e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
# FIXME: are we potentially losing important res.index info?
res = getattr(res, "_values", res)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just do
res = np.array(res)[0]? or
res = res.item()? (though I think we have deprecated .item(), we are bringing it back).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.array won't work because it'll lose the timezone. I'm thinking next(iter(res)) with a comment saying it's often equivalent to res[0].

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think .item() would be good here as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.item() will be good if/when we un-deprecate it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the current next(iter(res)) acceptable for the time being? extract_array(res)[0] would also work.

I'm eager to see this go in because I think we can get rid of a bunch more _try_cast calls

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep just mark it with the issue (or fix me) so it’s clear

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented

res = res[0]
else:
raise ValueError("Function does not reduce")
result = np.empty(ngroups, dtype="O")

counts[label] = group.shape[0]
Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/groupby/aggregate/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,26 @@ def test_agg_over_numpy_arrays():
tm.assert_frame_equal(result, expected)


def test_agg_tzaware_non_datetime_result():
    # discussed in GH#29589, fixed in GH#29641, operating on tzaware values
    # with a function that is not dtype-preserving
    dti = pd.date_range("2012-01-01", periods=4, tz="UTC")
    grouped = pd.DataFrame({"a": [0, 0, 1, 1], "b": dti}).groupby("a")["b"]

    # dtype-preserving aggregation: first element of each group keeps the tz
    result = grouped.agg(lambda group: group.iloc[0])
    expected = pd.Series(dti[::2], name="b").rename_axis("a")
    tm.assert_series_equal(result, expected)

    # non-dtype-preserving aggregation: extracting the year gives plain ints
    result = grouped.agg(lambda group: group.iloc[0].year)
    expected = pd.Series([2012, 2012], name="b").rename_axis("a")
    tm.assert_series_equal(result, expected)


def test_agg_timezone_round_trip():
# GH 15426
ts = pd.Timestamp("2016-01-01 12:00:00", tz="US/Pacific")
Expand Down