Skip to content

Commit f10ba02

Browse files
jbrockmendelpooja-subramaniam
authored andcommitted
ENH: support argmin/argmax with pyarrow durations (pandas-dev#50879)
1 parent 5d7bd5a commit f10ba02

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

pandas/core/arrays/arrow/array.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,12 @@ def _argmin_max(self, skipna: bool, method: str) -> int:
528528
f"arg{method} only implemented for pyarrow version >= 6.0"
529529
)
530530

531-
value = getattr(pc, method)(self._data, skip_nulls=skipna)
532-
return pc.index(self._data, value).as_py()
531+
data = self._data
532+
if pa.types.is_duration(data.type):
533+
data = data.cast(pa.int64())
534+
535+
value = getattr(pc, method)(data, skip_nulls=skipna)
536+
return pc.index(data, value).as_py()
533537

534538
def argmin(self, skipna: bool = True) -> int:
535539
return self._argmin_max(skipna, "min")

pandas/tests/extension/test_arrow.py

-15
Original file line numberDiff line numberDiff line change
@@ -869,13 +869,6 @@ def test_argmin_argmax(
869869
reason=f"{pa_dtype} only has 2 unique possible values",
870870
)
871871
)
872-
elif pa.types.is_duration(pa_dtype):
873-
request.node.add_marker(
874-
pytest.mark.xfail(
875-
raises=pa.ArrowNotImplementedError,
876-
reason=f"min_max not supported in pyarrow for {pa_dtype}",
877-
)
878-
)
879872
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)
880873

881874
@pytest.mark.parametrize(
@@ -894,21 +887,13 @@ def test_argmin_argmax(
894887
def test_argreduce_series(
895888
self, data_missing_for_sorting, op_name, skipna, expected, request
896889
):
897-
pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
898890
if pa_version_under6p0 and skipna:
899891
request.node.add_marker(
900892
pytest.mark.xfail(
901893
raises=NotImplementedError,
902894
reason="min_max not supported in pyarrow",
903895
)
904896
)
905-
elif not pa_version_under6p0 and pa.types.is_duration(pa_dtype) and skipna:
906-
request.node.add_marker(
907-
pytest.mark.xfail(
908-
raises=pa.ArrowNotImplementedError,
909-
reason=f"min_max not supported in pyarrow for {pa_dtype}",
910-
)
911-
)
912897
super().test_argreduce_series(
913898
data_missing_for_sorting, op_name, skipna, expected
914899
)

0 commit comments

Comments
 (0)