From 99073b86200248f6590f6806976bcffb248f79af Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 1 Dec 2020 07:59:48 -0800 Subject: [PATCH 01/12] REF: consolidate casting --- pandas/core/groupby/groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 422cf78bc927d..dfc1e7943c76a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1188,7 +1188,6 @@ def _python_agg_general(self, func, *args, **kwargs): if self.grouper._filter_empty_groups: mask = counts.ravel() > 0 - # since we are masking, make sure that we have a float object values = result if is_numeric_dtype(values.dtype): From 1ed9d8acf4889ef3471dcba08bf3953cd1825b3c Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 1 Dec 2020 08:24:35 -0800 Subject: [PATCH 02/12] lighter-weight casting --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index dfc1e7943c76a..95744e4549b5f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -50,7 +50,7 @@ class providing the base-class of operations. from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, Substitution, cache_readonly, doc -from pandas.core.dtypes.cast import maybe_cast_result, maybe_downcast_to_dtype +from pandas.core.dtypes.cast import maybe_downcast_to_dtype from pandas.core.dtypes.common import ( ensure_float, is_bool_dtype, From 21b4f0fc39ec771bf5e796f21205552191efb393 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 1 Dec 2020 09:17:40 -0800 Subject: [PATCH 03/12] simplify casting --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 95744e4549b5f..5bfe3ef9c7f9d 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1188,7 +1188,7 @@ def _python_agg_general(self, func, *args, **kwargs): if self.grouper._filter_empty_groups: mask = counts.ravel() > 0 - # since we are masking, make sure that we have a float object + # since we are masking, make sure that we have a float dtype values = result if is_numeric_dtype(values.dtype): values = ensure_float(values) From 56b42bb613bb62dd2cf946635020d51affaff75d Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 1 Dec 2020 19:29:18 -0800 Subject: [PATCH 04/12] REF: minimize python_agg_general groupby casting --- pandas/core/groupby/ops.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 8046be669ea51..93d47d6767312 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -47,6 +47,7 @@ is_extension_array_dtype, is_integer_dtype, is_numeric_dtype, + is_object_dtype, is_period_dtype, is_sparse, is_timedelta64_dtype, @@ -725,7 +726,15 @@ def _aggregate_series_pure_python(self, obj: Series, func: F): result[label] = res result = lib.maybe_convert_objects(result, try_float=0) - result = maybe_cast_result(result, obj, numeric_only=True) + + if is_object_dtype(result.dtype) and is_extension_array_dtype(obj.dtype): + # FIXME: kludge; we have tests for DecimalArray that get here + # but they only work because DecimalArray._from_sequence is + # strict in what inputs it accepts, which we cannot rely on. + inferred = lib.infer_dtype(result, skipna=False) + if inferred == obj.dtype.name == "decimal": + cls = obj.dtype.construct_array_type() + result = cls._from_sequence(result) return result, counts From e01487fa7fa75b45c93a141d5a4e53fe22bc84dd Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 2 Dec 2020 17:24:25 -0800 Subject: [PATCH 05/12] Handle Float64 --- pandas/core/groupby/ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 93d47d6767312..9ca5c23d23fea 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -727,7 +727,11 @@ def _aggregate_series_pure_python(self, obj: Series, func: F): result = lib.maybe_convert_objects(result, try_float=0) - if is_object_dtype(result.dtype) and is_extension_array_dtype(obj.dtype): + if is_numeric_dtype(obj.dtype): + # Needed to cast float64 back to Float64 + result = maybe_cast_result(result, obj, numeric_only=True) + + elif is_object_dtype(result.dtype) and is_extension_array_dtype(obj.dtype): # FIXME: kludge; we have tests for DecimalArray that get here # but they only work because DecimalArray._from_sequence is # strict in what inputs it accepts, which we cannot rely on. From cf5bd535fba04cd53158e93e5b942486b0d0624f Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 2 Dec 2020 17:30:49 -0800 Subject: [PATCH 06/12] TST: PeriodDtype --- pandas/tests/groupby/aggregate/test_other.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 5d0f6d6262899..6e0e091c5db8b 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -432,10 +432,13 @@ def test_agg_over_numpy_arrays(): tm.assert_frame_equal(result, expected) -def test_agg_tzaware_non_datetime_result(): +@pytest.mark.parametrize("as_period", [True, False]) +def test_agg_tzaware_non_datetime_result(as_period): # discussed in GH#29589, fixed in GH#29641, operating on tzaware values # with function that is not dtype-preserving dti = pd.date_range("2012-01-01", periods=4, tz="UTC") + if as_period: + dti = dti.tz_localize(None).to_period("D") df = DataFrame({"a": [0, 0, 1, 1], "b": dti}) gb = df.groupby("a") @@ -454,6 +457,8 @@ def test_agg_tzaware_non_datetime_result(): result = gb["b"].agg(lambda x: x.iloc[-1] - x.iloc[0]) expected = Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b") expected.index.name = "a" + if as_period: + expected = expected.astype(object) tm.assert_series_equal(result, expected) From f950c750bed296fad901c1a7cf71195a8274e9d4 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 3 Dec 2020 17:27:27 -0800 Subject: [PATCH 07/12] dont cast --- pandas/core/groupby/groupby.py | 3 ++- pandas/core/groupby/ops.py | 15 +-------------- pandas/tests/arrays/floating/test_function.py | 3 ++- pandas/tests/extension/decimal/test_decimal.py | 15 +++++++++++---- pandas/tests/groupby/aggregate/test_other.py | 3 ++- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3665a83b64ed6..652fda85c4b8e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1173,7 +1173,8 @@ def _python_agg_general(self, func, *args, **kwargs): if self.grouper._filter_empty_groups: mask = counts.ravel() > 0 - # since we are masking, make sure that we have a float dtype + + # since we are masking, make sure that we have a float object values = result if is_numeric_dtype(values.dtype): values = ensure_float(values) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index d2fb194754218..45212de7ea655 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -47,7 +47,6 @@ is_extension_array_dtype, is_integer_dtype, is_numeric_dtype, - is_object_dtype, is_period_dtype, is_sparse, is_timedelta64_dtype, @@ -719,19 +718,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F): result[label] = res result = lib.maybe_convert_objects(result, try_float=0) - - if is_numeric_dtype(obj.dtype): - # Needed to cast float64 back to Float64 - result = maybe_cast_result(result, obj, numeric_only=True) - - elif is_object_dtype(result.dtype) and is_extension_array_dtype(obj.dtype): - # FIXME: kludge; we have tests for DecimalArray that get here - # but they only work because DecimalArray._from_sequence is - # strict in what inputs it accepts, which we cannot rely on. - inferred = lib.infer_dtype(result, skipna=False) - if inferred == obj.dtype.name == "decimal": - cls = obj.dtype.construct_array_type() - result = cls._from_sequence(result) + # TODO: cast to EA once _from_sequence is reliably strict GH#38254 return result, counts diff --git a/pandas/tests/arrays/floating/test_function.py b/pandas/tests/arrays/floating/test_function.py index ef95eac316397..25354c377b494 100644 --- a/pandas/tests/arrays/floating/test_function.py +++ b/pandas/tests/arrays/floating/test_function.py @@ -157,8 +157,9 @@ def test_preserve_dtypes(op): # groupby result = getattr(df.groupby("A"), op)() + # GH#38254 until _from_sequence is reliably strict, we cannot retain Float64 expected = pd.DataFrame( - {"B": np.array([1.0, 3.0]), "C": pd.array([0.1, 3], dtype="Float64")}, + {"B": np.array([1.0, 3.0]), "C": pd.array([0.1, 3], dtype="float64")}, index=pd.Index(["a", "b"], name="A"), ) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 233b658d29782..bfc6b613a8712 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -445,7 +445,8 @@ def test_groupby_agg(): ) # single key, selected column - expected = pd.Series(to_decimal([data[0], data[3]])) + # GH#38254 until _from_sequence is reliably strict, we cant retain dtype + expected = pd.Series(to_decimal([data[0], data[3]])).astype(object) result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0]) tm.assert_series_equal(result, expected, check_names=False) result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0]) @@ -455,14 +456,16 @@ def test_groupby_agg(): expected = pd.Series( to_decimal([data[0], data[1], data[3]]), index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]), - ) + ).astype(object) result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0]) tm.assert_series_equal(result, expected, check_names=False) result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0]) tm.assert_series_equal(result, expected, check_names=False) # multiple columns - expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])}) + expected = pd.DataFrame( + {"id2": [0, 1], "decimals": to_decimal([data[0], data[3]]).astype(object)} + ) result = df.groupby("id1").agg(lambda x: x.iloc[0]) tm.assert_frame_equal(result, expected, check_names=False) @@ -478,7 +481,11 @@ def DecimalArray__my_sum(self): data = make_data()[:5] df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)}) - expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]])) + + # GH#38254 until _from_sequence is reliably strict, we cant retain dtype + expected = pd.Series( + to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]) + ).astype(object) result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum()) tm.assert_series_equal(result, expected, check_names=False) diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 6e0e091c5db8b..16a7bcae9f928 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -632,7 +632,8 @@ def test_groupby_agg_err_catching(err_cls): {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} ) - expected = Series(to_decimal([data[0], data[3]])) + # GH#38254 until _from_sequence is strict, we cannot reliably cast agg results + expected = Series(to_decimal([data[0], data[3]])).astype(object) def weird_func(x): # weird function that raise something other than TypeError or IndexError From 5a30b93fa4f699d949179f46e8351bd88c15504e Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 4 Dec 2020 07:34:47 -0800 Subject: [PATCH 08/12] whatsnew --- doc/source/whatsnew/v1.2.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 8182dfa4bce40..483aa56d980d4 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -773,6 +773,7 @@ Groupby/resample/rolling - Bug in :meth:`.DataFrameGroupBy.transform` would raise when used with ``axis=1`` and a transformation kernel (e.g. "shift") (:issue:`36308`) - Bug in :meth:`.DataFrameGroupBy.quantile` couldn't handle with arraylike ``q`` when grouping by columns (:issue:`33795`) - Bug in :meth:`DataFrameGroupBy.rank` with ``datetime64tz`` or period dtype incorrectly casting results to those dtypes instead of returning ``float64`` dtype (:issue:`38187`) +- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` with ``ExtensionDtype`` columns incorrectly casting results too aggressively (:issue:`38254`) Reshaping ^^^^^^^^^ From b4810492b3236581cc3478032efffbf6a8a56f65 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 08:49:01 -0800 Subject: [PATCH 09/12] move whatsnew to 1.3.0 --- doc/source/whatsnew/v1.2.0.rst | 1 - doc/source/whatsnew/v1.3.0.rst | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index d08fe1d3d16bb..4294871b56bcb 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -794,7 +794,6 @@ Groupby/resample/rolling - Bug in :meth:`.DataFrameGroupBy.apply` dropped values on ``nan`` group when returning the same axes with the original frame (:issue:`38227`) - Bug in :meth:`.DataFrameGroupBy.quantile` couldn't handle with arraylike ``q`` when grouping by columns (:issue:`33795`) - Bug in :meth:`DataFrameGroupBy.rank` with ``datetime64tz`` or period dtype incorrectly casting results to those dtypes instead of returning ``float64`` dtype (:issue:`38187`) -- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` with ``ExtensionDtype`` columns incorrectly casting results too aggressively (:issue:`38254`) Reshaping ^^^^^^^^^ diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index b40f012f034b6..0c60c6f9c409e 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -163,7 +163,7 @@ Plotting Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ - +- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` with ``ExtensionDtype`` columns incorrectly casting results too aggressively (:issue:`38254`) - - From 2047f3cacb2c96a1d531d3896643528c8da9378d Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 09:17:23 -0800 Subject: [PATCH 10/12] CLN: remove unused import --- pandas/core/groupby/ops.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 22519ca57904d..dd42bb8a2eb51 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -28,11 +28,7 @@ from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly -from pandas.core.dtypes.cast import ( - maybe_cast_result, - maybe_cast_result_dtype, - maybe_downcast_to_dtype, -) +from pandas.core.dtypes.cast import maybe_cast_result_dtype, maybe_downcast_to_dtype from pandas.core.dtypes.common import ( ensure_float, ensure_float64, From 7bc78eb592732eb568badf3bbc13933f89836fd8 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 8 Dec 2020 13:00:56 -0800 Subject: [PATCH 11/12] retain Float64 --- pandas/tests/arrays/floating/test_function.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/tests/arrays/floating/test_function.py b/pandas/tests/arrays/floating/test_function.py index 25354c377b494..ef95eac316397 100644 --- a/pandas/tests/arrays/floating/test_function.py +++ b/pandas/tests/arrays/floating/test_function.py @@ -157,9 +157,8 @@ def test_preserve_dtypes(op): # groupby result = getattr(df.groupby("A"), op)() - # GH#38254 until _from_sequence is reliably strict, we cannot retain Float64 expected = pd.DataFrame( - {"B": np.array([1.0, 3.0]), "C": pd.array([0.1, 3], dtype="float64")}, + {"B": np.array([1.0, 3.0]), "C": pd.array([0.1, 3], dtype="Float64")}, index=pd.Index(["a", "b"], name="A"), ) tm.assert_frame_equal(result, expected) From eddb08911f60dc503e5e2614c593508b097d087a Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 11 Dec 2020 10:11:41 -0800 Subject: [PATCH 12/12] fix test --- pandas/tests/groupby/aggregate/test_other.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 16a7bcae9f928..67b99678ebec5 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -458,7 +458,8 @@ def test_agg_tzaware_non_datetime_result(as_period): expected = Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b") expected.index.name = "a" if as_period: - expected = expected.astype(object) + expected = Series([pd.offsets.Day(1), pd.offsets.Day(1)], name="b") + expected.index.name = "a" tm.assert_series_equal(result, expected)