diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py
index 052eb7792a19c..60a9b3d4fd30e 100644
--- a/pandas/compat/__init__.py
+++ b/pandas/compat/__init__.py
@@ -30,6 +30,7 @@
     pa_version_under7p0,
     pa_version_under8p0,
     pa_version_under9p0,
+    pa_version_under11p0,
 )
 
 
@@ -159,6 +160,7 @@ def get_lzma_file() -> type[pandas.compat.compressors.LZMAFile]:
     "pa_version_under7p0",
     "pa_version_under8p0",
     "pa_version_under9p0",
+    "pa_version_under11p0",
     "IS64",
     "PY39",
     "PY310",
diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py
index ea8e18437fcfb..020ec346490ff 100644
--- a/pandas/compat/pyarrow.py
+++ b/pandas/compat/pyarrow.py
@@ -13,8 +13,10 @@
     pa_version_under8p0 = _palv < Version("8.0.0")
     pa_version_under9p0 = _palv < Version("9.0.0")
     pa_version_under10p0 = _palv < Version("10.0.0")
+    pa_version_under11p0 = _palv < Version("11.0.0")
 except ImportError:
     pa_version_under7p0 = True
     pa_version_under8p0 = True
     pa_version_under9p0 = True
     pa_version_under10p0 = True
+    pa_version_under11p0 = True
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index a4cde823c6713..7c98c2804306e 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -29,6 +29,7 @@
     pa_version_under7p0,
     pa_version_under8p0,
     pa_version_under9p0,
+    pa_version_under11p0,
 )
 from pandas.util._decorators import doc
 from pandas.util._validators import validate_fillna_kwargs
@@ -130,6 +131,38 @@ def floordiv_compat(
 ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
 
 
+def get_unit_from_pa_dtype(pa_dtype) -> str:
+    """
+    Return the resolution unit of a unit-bearing pyarrow temporal dtype.
+
+    Parameters
+    ----------
+    pa_dtype : pa.DataType
+        A dtype carrying a unit, e.g. ``pa.time64("us")``.
+
+    Returns
+    -------
+    str
+        The unit, e.g. ``"us"``.
+
+    Raises
+    ------
+    ValueError
+        If ``pa_version_under11p0`` is set and no valid unit can be parsed
+        from ``str(pa_dtype)``.
+    """
+    # https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
+    if pa_version_under11p0:
+        # Recover the unit from the string form instead of reading ``.unit``,
+        # e.g. "time64[us]" -> "us".  Cut at the first "," so that
+        # "timestamp[us, tz=UTC]" also yields "us" instead of raising.
+        unit = str(pa_dtype).split("[", 1)[-1][:-1].split(",", 1)[0]
+        if unit not in ["s", "ms", "us", "ns"]:
+            raise ValueError(pa_dtype)
+        return unit
+    return pa_dtype.unit
+
+
 def to_pyarrow_type(
     dtype: ArrowDtype | pa.DataType | Dtype | None,
 ) -> pa.DataType | None:
@@ -1039,6 +1072,13 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
         elif name in ["min", "max", "sum"] and pa.types.is_duration(pa_type):
             data_to_reduce = self._data.cast(pa.int64())
 
+        elif name in ["median", "mean", "std", "sem"] and pa.types.is_temporal(pa_type):
+            nbits = pa_type.bit_width
+            if nbits == 32:
+                data_to_reduce = self._data.cast(pa.int32())
+            else:
+                data_to_reduce = self._data.cast(pa.int64())
+
         if name == "sem":
 
             def pyarrow_meth(data, skip_nulls, **kwargs):
@@ -1076,6 +1116,22 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
 
         if name in ["min", "max", "sum"] and pa.types.is_duration(pa_type):
             result = result.cast(pa_type)
+        if name in ["median", "mean"] and pa.types.is_temporal(pa_type):
+            result = result.cast(pa_type)
+        if name in ["std", "sem"] and pa.types.is_temporal(pa_type):
+            result = result.cast(pa.int64())
+            if pa.types.is_duration(pa_type):
+                result = result.cast(pa_type)
+            elif pa.types.is_time(pa_type):
+                unit = get_unit_from_pa_dtype(pa_type)
+                result = result.cast(pa.duration(unit))
+            elif pa.types.is_date(pa_type):
+                # go with closest available unit, i.e. "s"
+                result = result.cast(pa.duration("s"))
+            else:
+                # i.e. timestamp
+                result = result.cast(pa.duration(pa_type.unit))
+
         return result.as_py()
 
     def __setitem__(self, key, value) -> None:
diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py
index 522a0d59e4161..84d82829d00e7 100644
--- a/pandas/tests/extension/test_arrow.py
+++ b/pandas/tests/extension/test_arrow.py
@@ -513,13 +513,6 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
         elif all_numeric_reductions == "sem" and pa_version_under8p0:
             request.node.add_marker(xfail_mark)
 
-        elif all_numeric_reductions in [
-            "mean",
-            "median",
-            "std",
-            "sem",
-        ] and pa.types.is_temporal(pa_dtype):
-            request.node.add_marker(xfail_mark)
         elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
             "sem",
             "std",