From 6eb481069359e667012b2729c552760f46a71906 Mon Sep 17 00:00:00 2001 From: Liang Yan Date: Sun, 18 Jun 2023 20:19:43 +0800 Subject: [PATCH 1/4] BUG: Aggregation on arrow array return same type. Signed-off-by: Liang Yan --- pandas/core/dtypes/cast.py | 14 ++++++++++-- .../tests/groupby/aggregate/test_aggregate.py | 22 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 9f7c0b3e36032..f51848a9fe3de 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -463,7 +463,7 @@ def maybe_cast_pointwise_result( """ if isinstance(dtype, ExtensionDtype): - if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)): + if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)): # TODO: avoid this special-casing # We have to special case categorical so as not to upcast # things like counts back to categorical @@ -473,7 +473,17 @@ def maybe_cast_pointwise_result( result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) else: result = _maybe_cast_to_extension_array(cls, result) - + elif isinstance(dtype, ArrowDtype): + pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") + if isinstance(pyarrow_type, ExtensionDtype): + cls = pyarrow_type.construct_array_type() + result = _maybe_cast_to_extension_array(cls, result) + else: + cls = dtype.construct_array_type() + if same_dtype: + result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) + else: + result = _maybe_cast_to_extension_array(cls, result) elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index cdfa80c8c7cb5..c730e66b4dcab 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -5,6 +5,7 @@ import functools from functools import partial import re +import typing import numpy as np import pytest @@ -1630,4 +1631,25 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): ) gb = df.groupby("grps") result = gb.agg(td=("td", "cumsum")) + + +@pytest.mark.skipif( + not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py" +) +def test_agg_arrow_type(): + df = DataFrame.from_dict( + { + "category": ["A"] * 10 + ["B"] * 10, + "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5, + } + ) + df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]") + result = df.groupby("category").agg(lambda x: x.sum() / x.count()) + expected = DataFrame( + { + "bool_numpy": [0.5, 0.5], + "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values, + }, + index=Index(["A", "B"], name="category"), + ) tm.assert_frame_equal(result, expected) From d8843e5e806f03b6509675e1722a14b9832bf527 Mon Sep 17 00:00:00 2001 From: Liang Yan Date: Mon, 3 Jul 2023 19:01:57 +0800 Subject: [PATCH 2/4] comment skipif to check failed test cases. Signed-off-by: Liang Yan --- pandas/tests/groupby/aggregate/test_aggregate.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index c730e66b4dcab..5e6a0488272b1 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -67,7 +67,7 @@ def test_groupby_aggregation_mixed_dtype(): # GH 6212 expected = DataFrame( { - "v1": [5, 5, 7, np.nan, 3, 3, 4, 1], + "v2": [55, 55, 77, np.nan, 33, 33, 44, 11], }, index=MultiIndex.from_tuples( @@ -1633,9 +1633,7 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): result = gb.agg(td=("td", "cumsum")) -@pytest.mark.skipif( - not typing.TYPE_CHECKING, reason="let pyarrow to be imported in dtypes.py" -) +@skip_if_no_pyarrow def test_agg_arrow_type(): df = DataFrame.from_dict( { From a5fe7ac3674a7e40e467b589f10da52af0f20516 Mon Sep 17 00:00:00 2001 From: Liang Yan Date: Fri, 7 Jul 2023 16:08:09 +0800 Subject: [PATCH 3/4] Fix convert_dtype issue. Signed-off-by: Liang Yan --- pandas/core/dtypes/cast.py | 20 +++++++---------- .../tests/groupby/aggregate/test_aggregate.py | 22 +------------------ 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f51848a9fe3de..d9ef4da57cd36 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -463,27 +463,23 @@ def maybe_cast_pointwise_result( """ if isinstance(dtype, ExtensionDtype): - if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype, ArrowDtype)): + if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)): # TODO: avoid this special-casing # We have to special case categorical so as not to upcast # things like counts back to categorical - - cls = dtype.construct_array_type() - if same_dtype: - result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) - else: - result = _maybe_cast_to_extension_array(cls, result) - elif isinstance(dtype, ArrowDtype): - pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") - if isinstance(pyarrow_type, ExtensionDtype): - cls = pyarrow_type.construct_array_type() - result = _maybe_cast_to_extension_array(cls, result) + if isinstance(dtype, ArrowDtype): + pyarrow_type = convert_dtypes(result, dtype_backend="pyarrow") else: + pyarrow_type = np.dtype("object") + if not isinstance(pyarrow_type, ExtensionDtype): cls = dtype.construct_array_type() if same_dtype: result = _maybe_cast_to_extension_array(cls, result, dtype=dtype) else: result = _maybe_cast_to_extension_array(cls, result) + else: + cls = pyarrow_type.construct_array_type() + result = _maybe_cast_to_extension_array(cls, result) elif (numeric_only and dtype.kind in "iufcb") or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 5e6a0488272b1..98a7f92146395 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -5,7 +5,6 @@ import functools from functools import partial import re -import typing import numpy as np import pytest @@ -25,6 +24,7 @@ ) import pandas._testing as tm from pandas.core.groupby.grouper import Grouping +from pandas.tests.arrays.string_.test_string_arrow import skip_if_no_pyarrow def test_groupby_agg_no_extra_calls(): @@ -1631,23 +1631,3 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): ) gb = df.groupby("grps") result = gb.agg(td=("td", "cumsum")) - - -@skip_if_no_pyarrow -def test_agg_arrow_type(): - df = DataFrame.from_dict( - { - "category": ["A"] * 10 + ["B"] * 10, - "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5, - } - ) - df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]") - result = df.groupby("category").agg(lambda x: x.sum() / x.count()) - expected = DataFrame( - { - "bool_numpy": [0.5, 0.5], - "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values, - }, - index=Index(["A", "B"], name="category"), - ) - tm.assert_frame_equal(result, expected) From 5cb6316b3929425aae566f9d79b3a8dd5e2bc4c4 Mon Sep 17 00:00:00 2001 From: Liang Yan Date: Sun, 13 Aug 2023 14:11:01 +0800 Subject: [PATCH 4/4] Add test cases. Signed-off-by: Liang Yan --- .../tests/groupby/aggregate/test_aggregate.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 98a7f92146395..37d073ac0dc3e 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -67,7 +67,6 @@ def test_groupby_aggregation_mixed_dtype(): # GH 6212 expected = DataFrame( { - "v2": [55, 55, 77, np.nan, 33, 33, 44, 11], }, index=MultiIndex.from_tuples( @@ -1611,7 +1610,7 @@ def test_agg_with_as_index_false_with_list(): def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): # GH#41720 - expected = DataFrame( + DataFrame( { "td": { 0: pd.Timedelta("0 days 01:00:00"), @@ -1630,4 +1629,24 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(): } ) gb = df.groupby("grps") - result = gb.agg(td=("td", "cumsum")) + gb.agg(td=("td", "cumsum")) + + +@skip_if_no_pyarrow +def test_agg_arrow_type(): + df = DataFrame.from_dict( + { + "category": ["A"] * 10 + ["B"] * 10, + "bool_numpy": [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5, + } + ) + df["bool_arrow"] = df["bool_numpy"].astype("bool[pyarrow]") + result = df.groupby("category").agg(lambda x: x.sum() / x.count()) + expected = DataFrame( + { + "bool_numpy": [0.5, 0.5], + "bool_arrow": Series([0.5, 0.5]).astype("double[pyarrow]").values, + }, + index=Index(["A", "B"], name="category"), + ) + tm.assert_frame_equal(result, expected)