Skip to content

Commit 2288cb7

Browse files
jbrockmendel authored and pooja-subramaniam committed
ENH: pyarrow temporal dtypes support quantile in some cases (pandas-dev#50868)
1 parent bdfaeac commit 2288cb7

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

pandas/core/arrays/arrow/array.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,23 @@ def _quantile(
         -------
         same type as self
         """
-        result = pc.quantile(self._data, q=qs, interpolation=interpolation)
+        pa_dtype = self._data.type
+
+        data = self._data
+        if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]:
+            # https://github.com/apache/arrow/issues/33769 in these cases
+            # we can cast to ints and back
+            nbits = pa_dtype.bit_width
+            if nbits == 32:
+                data = data.cast(pa.int32())
+            else:
+                data = data.cast(pa.int64())
+
+        result = pc.quantile(data, q=qs, interpolation=interpolation)
+
+        if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]:
+            result = result.cast(pa_dtype)
+
         return type(self)(result)
 
     def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:

pandas/tests/extension/test_arrow.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1279,7 +1279,11 @@ def test_quantile(data, interpolation, quantile, request):
             ser.quantile(q=quantile, interpolation=interpolation)
         return
 
-    if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)):
+    if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
+        pass
+    elif pa.types.is_temporal(data._data.type) and interpolation in ["lower", "higher"]:
+        pass
+    else:
         request.node.add_marker(
             pytest.mark.xfail(
                 raises=pa.ArrowNotImplementedError,
@@ -1293,10 +1297,10 @@ def test_quantile(data, interpolation, quantile, request):
         assert result == data[0]
     else:
         # Just check the values
-        result = result.astype("float64[pyarrow]")
-        expected = pd.Series(
-            data.take([0, 0]).astype("float64[pyarrow]"), index=[0.5, 0.5]
-        )
+        expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
+        if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
+            expected = expected.astype("float64[pyarrow]")
+            result = result.astype("float64[pyarrow]")
         tm.assert_series_equal(result, expected)
 
 

0 commit comments

Comments
 (0)