Skip to content

Commit 0105aa2

Browse files
authored
BUG/ENH: fix pyarrow quantile xfails (#50983)
1 parent 38e802c commit 0105aa2

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

pandas/core/arrays/arrow/array.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ def _quantile(
12551255
pa_dtype = self._data.type
12561256

12571257
data = self._data
1258-
if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]:
1258+
if pa.types.is_temporal(pa_dtype):
12591259
# https://github.com/apache/arrow/issues/33769 in these cases
12601260
# we can cast to ints and back
12611261
nbits = pa_dtype.bit_width
@@ -1266,7 +1266,12 @@ def _quantile(
12661266

12671267
result = pc.quantile(data, q=qs, interpolation=interpolation)
12681268

1269-
if pa.types.is_temporal(pa_dtype) and interpolation in ["lower", "higher"]:
1269+
if pa.types.is_temporal(pa_dtype):
1270+
nbits = pa_dtype.bit_width
1271+
if nbits == 32:
1272+
result = result.cast(pa.int32())
1273+
else:
1274+
result = result.cast(pa.int64())
12701275
result = result.cast(pa_dtype)
12711276

12721277
return type(self)(result)

pandas/tests/extension/test_arrow.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,7 @@ def test_quantile(data, interpolation, quantile, request):
13331333

13341334
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
13351335
pass
1336-
elif pa.types.is_temporal(data._data.type) and interpolation in ["lower", "higher"]:
1336+
elif pa.types.is_temporal(data._data.type):
13371337
pass
13381338
else:
13391339
request.node.add_marker(
@@ -1345,6 +1345,28 @@ def test_quantile(data, interpolation, quantile, request):
13451345
data = data.take([0, 0, 0])
13461346
ser = pd.Series(data)
13471347
result = ser.quantile(q=quantile, interpolation=interpolation)
1348+
1349+
if pa.types.is_timestamp(pa_dtype) and interpolation not in ["lower", "higher"]:
1350+
# rounding error will make the check below fail
1351+
# (e.g. '2020-01-01 01:01:01.000001' vs '2020-01-01 01:01:01.000001024'),
1352+
# so we'll check for now that we match the numpy analogue
1353+
if pa_dtype.tz:
1354+
pd_dtype = f"M8[{pa_dtype.unit}, {pa_dtype.tz}]"
1355+
else:
1356+
pd_dtype = f"M8[{pa_dtype.unit}]"
1357+
ser_np = ser.astype(pd_dtype)
1358+
1359+
expected = ser_np.quantile(q=quantile, interpolation=interpolation)
1360+
if quantile == 0.5:
1361+
if pa_dtype.unit == "us":
1362+
expected = expected.to_pydatetime(warn=False)
1363+
assert result == expected
1364+
else:
1365+
if pa_dtype.unit == "us":
1366+
expected = expected.dt.floor("us")
1367+
tm.assert_series_equal(result, expected.astype(data.dtype))
1368+
return
1369+
13481370
if quantile == 0.5:
13491371
assert result == data[0]
13501372
else:

0 commit comments

Comments
 (0)