From c65153a20e30568e22f3dc54225208dd8d1eeae3 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 11:45:24 -0800 Subject: [PATCH 1/3] ENH: support reductions for pyarrow temporal types --- pandas/core/arrays/arrow/array.py | 25 +++++++++++++++++++++++++ pandas/tests/extension/test_arrow.py | 7 ------- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index e4b4919c63679..0a7cc13b9de4a 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1046,6 +1046,13 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): elif name in ["min", "max", "sum"] and pa.types.is_duration(pa_type): data_to_reduce = self._data.cast(pa.int64()) + elif name in ["median", "mean", "std", "sem"] and pa.types.is_temporal(pa_type): + nbits = pa_type.bit_width + if nbits == 32: + data_to_reduce = self._data.cast(pa.int32()) + else: + data_to_reduce = self._data.cast(pa.int64()) + if name == "sem": def pyarrow_meth(data, skip_nulls, **kwargs): @@ -1083,6 +1090,24 @@ def pyarrow_meth(data, skip_nulls, **kwargs): if name in ["min", "max", "sum"] and pa.types.is_duration(pa_type): result = result.cast(pa_type) + if name in ["median", "mean"] and pa.types.is_temporal(pa_type): + result = result.cast(pa_type) + if name in ["std", "sem"] and pa.types.is_temporal(pa_type): + result = result.cast(pa.int64()) + if pa.types.is_duration(pa_type): + result = result.cast(pa_type) + elif pa.types.is_time(pa_type): + # TODO: with time types we should probably retain "unit", + # but not clear how to get that since pa_type.unit raises + # AttributeError + result = result.cast(pa.duration("ns")) + elif pa.types.is_date(pa_type): + # go with closest available unit, i.e. "s" + result = result.cast(pa.duration("s")) + else: + # i.e. timestamp + result = result.cast(pa.duration(pa_type.unit)) + return result.as_py() def __setitem__(self, key, value) -> None: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index e51b0aa318582..a202c48b0b5f9 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -542,13 +542,6 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request): ) ) - elif all_numeric_reductions in [ - "mean", - "median", - "std", - "sem", - ] and pa.types.is_temporal(pa_dtype): - request.node.add_marker(xfail_mark) elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in { "sem", "std", From 49b0db0c70a33b312b02f4c97817ca810f13a98c Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 8 Feb 2023 08:27:41 -0800 Subject: [PATCH 2/3] unit check --- pandas/compat/__init__.py | 2 ++ pandas/compat/pyarrow.py | 2 ++ pandas/core/arrays/arrow/array.py | 17 +++++++++++++---- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py index 052eb7792a19c..60a9b3d4fd30e 100644 --- a/pandas/compat/__init__.py +++ b/pandas/compat/__init__.py @@ -30,6 +30,7 @@ pa_version_under7p0, pa_version_under8p0, pa_version_under9p0, + pa_version_under11p0, ) @@ -159,6 +160,7 @@ def get_lzma_file() -> type[pandas.compat.compressors.LZMAFile]: "pa_version_under7p0", "pa_version_under8p0", "pa_version_under9p0", + "pa_version_under11p0", "IS64", "PY39", "PY310", diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py index ea8e18437fcfb..020ec346490ff 100644 --- a/pandas/compat/pyarrow.py +++ b/pandas/compat/pyarrow.py @@ -13,8 +13,10 @@ pa_version_under8p0 = _palv < Version("8.0.0") pa_version_under9p0 = _palv < Version("9.0.0") pa_version_under10p0 = _palv < Version("10.0.0") + pa_version_under11p0 = _palv < Version("11.0.0") except ImportError: pa_version_under7p0 = True pa_version_under8p0 = True pa_version_under9p0 = True pa_version_under10p0 = True + pa_version_under11p0 = True diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 3693c095ca215..032a3d213500a 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -29,6 +29,7 @@ pa_version_under7p0, pa_version_under8p0, pa_version_under9p0, + pa_version_under11p0, ) from pandas.util._decorators import doc from pandas.util._validators import validate_fillna_kwargs @@ -130,6 +131,16 @@ def floordiv_compat( ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray") +def get_unit_from_pa_dtype(pa_dtype): + # https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804 + if pa_version_under11p0: + unit = str(pa_dtype).split("[")[-1][:-1] + if unit not in ["s", "ms", "us", "ns"]: + raise ValueError(pa_dtype) + return unit + return pa_dtype.unit + + def to_pyarrow_type( dtype: ArrowDtype | pa.DataType | Dtype | None, ) -> pa.DataType | None: @@ -1090,10 +1101,8 @@ def pyarrow_meth(data, skip_nulls, **kwargs): if pa.types.is_duration(pa_type): result = result.cast(pa_type) elif pa.types.is_time(pa_type): - # TODO: with time types we should probably retain "unit", - # but not clear how to get that since pa_type.unit raises - # AttributeError - result = result.cast(pa.duration("ns")) + unit = get_unit_from_pa_dtype(pa_type) + result = result.cast(pa.duration(unit)) elif pa.types.is_date(pa_type): # go with closest available unit, i.e. "s" result = result.cast(pa.duration("s")) From 715f69b0439519d3dc23a7ab34209539d2ce2406 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 8 Feb 2023 09:11:37 -0800 Subject: [PATCH 3/3] lint fixup --- pandas/core/arrays/arrow/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 032a3d213500a..7c98c2804306e 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -134,7 +134,7 @@ def floordiv_compat( def get_unit_from_pa_dtype(pa_dtype): # https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804 if pa_version_under11p0: - unit = str(pa_dtype).split("[")[-1][:-1] + unit = str(pa_dtype).split("[", 1)[-1][:-1] if unit not in ["s", "ms", "us", "ns"]: raise ValueError(pa_dtype) return unit