From 91ac7bf515b7c88e8a27520398bee01f91a0ce88 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 25 Jan 2023 17:07:23 -0800 Subject: [PATCH] BUG/ENH: fix pyarrow quantile xfails --- pandas/core/arrays/arrow/array.py | 9 +++++++-- pandas/tests/extension/test_arrow.py | 24 +++++++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0e70b3795bc85..425346b5a6983 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1243,7 +1243,7 @@ def _quantile( pa_dtype = self._data.type data = self._data - if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]: + if pa.types.is_temporal(pa_dtype): # https://github.com/apache/arrow/issues/33769 in these cases # we can cast to ints and back nbits = pa_dtype.bit_width @@ -1254,7 +1254,12 @@ def _quantile( result = pc.quantile(data, q=qs, interpolation=interpolation) - if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]: + if pa.types.is_temporal(pa_dtype): + nbits = pa_dtype.bit_width + if nbits == 32: + result = result.cast(pa.int32()) + else: + result = result.cast(pa.int64()) result = result.cast(pa_dtype) return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 9db49470edaf2..2244346a03924 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1325,7 +1325,7 @@ def test_quantile(data, interpolation, quantile, request): if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): pass - elif pa.types.is_temporal(data._data.type) and interpolation in ["lower", "higher"]: + elif pa.types.is_temporal(data._data.type): pass else: request.node.add_marker( @@ -1337,6 +1337,28 @@ def test_quantile(data, interpolation, quantile, request): data = data.take([0, 0, 0]) ser = pd.Series(data) result = ser.quantile(q=quantile, interpolation=interpolation) + + if pa.types.is_timestamp(pa_dtype) and interpolation not in ["lower", "higher"]: + # rounding error will make the check below fail + # (e.g. '2020-01-01 01:01:01.000001' vs '2020-01-01 01:01:01.000001024'), + # so we'll check for now that we match the numpy analogue + if pa_dtype.tz: + pd_dtype = f"M8[{pa_dtype.unit}, {pa_dtype.tz}]" + else: + pd_dtype = f"M8[{pa_dtype.unit}]" + ser_np = ser.astype(pd_dtype) + + expected = ser_np.quantile(q=quantile, interpolation=interpolation) + if quantile == 0.5: + if pa_dtype.unit == "us": + expected = expected.to_pydatetime(warn=False) + assert result == expected + else: + if pa_dtype.unit == "us": + expected = expected.dt.floor("us") + tm.assert_series_equal(result, expected.astype(data.dtype)) + return + if quantile == 0.5: assert result == data[0] else: