Skip to content

Commit 384a603

Browse files
authored
ENH: fix a bunch of pyarrow duration xfails (#50669)
* ENH: fix a bunch of pyarrow duration xfails
* mypy fixup
1 parent a38a24e commit 384a603

File tree

2 files changed

+23
-110
lines changed

2 files changed

+23
-110
lines changed

pandas/core/arrays/arrow/array.py

+22
Original file line number | Diff line number | Diff line change
@@ -653,6 +653,15 @@ def factorize(
653653
use_na_sentinel: bool = True,
654654
) -> tuple[np.ndarray, ExtensionArray]:
655655
null_encoding = "mask" if use_na_sentinel else "encode"
656+
657+
pa_type = self._data.type
658+
if pa.types.is_duration(pa_type):
659+
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
660+
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
661+
indices, uniques = arr.factorize(use_na_sentinel=use_na_sentinel)
662+
uniques = uniques.astype(self.dtype)
663+
return indices, uniques
664+
656665
encoded = self._data.dictionary_encode(null_encoding=null_encoding)
657666
if encoded.length() == 0:
658667
indices = np.array([], dtype=np.intp)
@@ -849,6 +858,12 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
849858
-------
850859
ArrowExtensionArray
851860
"""
861+
if pa.types.is_duration(self._data.type):
862+
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
863+
arr = cast(ArrowExtensionArrayT, self.astype("int64[pyarrow]"))
864+
result = arr.unique()
865+
return cast(ArrowExtensionArrayT, result.astype(self.dtype))
866+
852867
return type(self)(pc.unique(self._data))
853868

854869
def value_counts(self, dropna: bool = True) -> Series:
@@ -868,6 +883,13 @@ def value_counts(self, dropna: bool = True) -> Series:
868883
--------
869884
Series.value_counts
870885
"""
886+
if pa.types.is_duration(self._data.type):
887+
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
888+
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
889+
result = arr.value_counts()
890+
result.index = result.index.astype(self.dtype)
891+
return result
892+
871893
from pandas import (
872894
Index,
873895
Series,

pandas/tests/extension/test_arrow.py

+1 -110
Original file line number | Diff line number | Diff line change
@@ -508,13 +508,6 @@ def test_groupby_extension_no_sort(self, data_for_grouping, request):
508508
reason=f"{pa_dtype} only has 2 unique possible values",
509509
)
510510
)
511-
elif pa.types.is_duration(pa_dtype):
512-
request.node.add_marker(
513-
pytest.mark.xfail(
514-
raises=pa.ArrowNotImplementedError,
515-
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
516-
)
517-
)
518511
super().test_groupby_extension_no_sort(data_for_grouping)
519512

520513
def test_groupby_extension_transform(self, data_for_grouping, request):
@@ -525,13 +518,6 @@ def test_groupby_extension_transform(self, data_for_grouping, request):
525518
reason=f"{pa_dtype} only has 2 unique possible values",
526519
)
527520
)
528-
elif pa.types.is_duration(pa_dtype):
529-
request.node.add_marker(
530-
pytest.mark.xfail(
531-
raises=pa.ArrowNotImplementedError,
532-
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
533-
)
534-
)
535521
with tm.maybe_produces_warning(
536522
PerformanceWarning,
537523
pa_version_under7p0 and not pa.types.is_duration(pa_dtype),
@@ -542,14 +528,6 @@ def test_groupby_extension_transform(self, data_for_grouping, request):
542528
def test_groupby_extension_apply(
543529
self, data_for_grouping, groupby_apply_op, request
544530
):
545-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
546-
if pa.types.is_duration(pa_dtype):
547-
request.node.add_marker(
548-
pytest.mark.xfail(
549-
raises=pa.ArrowNotImplementedError,
550-
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
551-
)
552-
)
553531
with tm.maybe_produces_warning(
554532
PerformanceWarning,
555533
pa_version_under7p0 and not pa.types.is_duration(pa_dtype),
@@ -567,13 +545,6 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
567545
reason=f"{pa_dtype} only has 2 unique possible values",
568546
)
569547
)
570-
elif pa.types.is_duration(pa_dtype):
571-
request.node.add_marker(
572-
pytest.mark.xfail(
573-
raises=pa.ArrowNotImplementedError,
574-
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
575-
)
576-
)
577548
with tm.maybe_produces_warning(
578549
PerformanceWarning,
579550
pa_version_under7p0 and not pa.types.is_duration(pa_dtype),
@@ -796,25 +767,9 @@ def test_diff(self, data, periods, request):
796767
@pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning")
797768
@pytest.mark.parametrize("dropna", [True, False])
798769
def test_value_counts(self, all_data, dropna, request):
799-
pa_dtype = all_data.dtype.pyarrow_dtype
800-
if pa.types.is_duration(pa_dtype):
801-
request.node.add_marker(
802-
pytest.mark.xfail(
803-
raises=pa.ArrowNotImplementedError,
804-
reason=f"value_count has no kernel for {pa_dtype}",
805-
)
806-
)
807770
super().test_value_counts(all_data, dropna)
808771

809772
def test_value_counts_with_normalize(self, data, request):
810-
pa_dtype = data.dtype.pyarrow_dtype
811-
if pa.types.is_duration(pa_dtype):
812-
request.node.add_marker(
813-
pytest.mark.xfail(
814-
raises=pa.ArrowNotImplementedError,
815-
reason=f"value_count has no pyarrow kernel for {pa_dtype}",
816-
)
817-
)
818773
with tm.maybe_produces_warning(
819774
PerformanceWarning,
820775
pa_version_under7p0 and not pa.types.is_duration(pa_dtype),
@@ -896,17 +851,6 @@ def test_nargsort(self, data_missing_for_sorting, na_position, expected):
896851

897852
@pytest.mark.parametrize("ascending", [True, False])
898853
def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
899-
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
900-
if pa.types.is_duration(pa_dtype) and not ascending:
901-
request.node.add_marker(
902-
pytest.mark.xfail(
903-
raises=pa.ArrowNotImplementedError,
904-
reason=(
905-
f"unique has no pyarrow kernel "
906-
f"for {pa_dtype} when ascending={ascending}"
907-
),
908-
)
909-
)
910854
with tm.maybe_produces_warning(
911855
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
912856
):
@@ -925,76 +869,23 @@ def test_sort_values_missing(
925869

926870
@pytest.mark.parametrize("ascending", [True, False])
927871
def test_sort_values_frame(self, data_for_sorting, ascending, request):
928-
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
929-
if pa.types.is_duration(pa_dtype):
930-
request.node.add_marker(
931-
pytest.mark.xfail(
932-
raises=pa.ArrowNotImplementedError,
933-
reason=(
934-
f"dictionary_encode has no pyarrow kernel "
935-
f"for {pa_dtype} when ascending={ascending}"
936-
),
937-
)
938-
)
939872
with tm.maybe_produces_warning(
940873
PerformanceWarning,
941874
pa_version_under7p0 and not pa.types.is_duration(pa_dtype),
942875
check_stacklevel=False,
943876
):
944877
super().test_sort_values_frame(data_for_sorting, ascending)
945878

946-
@pytest.mark.parametrize("box", [pd.Series, lambda x: x])
947-
@pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique])
948-
def test_unique(self, data, box, method, request):
949-
pa_dtype = data.dtype.pyarrow_dtype
950-
if pa.types.is_duration(pa_dtype):
951-
request.node.add_marker(
952-
pytest.mark.xfail(
953-
raises=pa.ArrowNotImplementedError,
954-
reason=f"unique has no pyarrow kernel for {pa_dtype}.",
955-
)
956-
)
957-
super().test_unique(data, box, method)
958-
959879
def test_factorize(self, data_for_grouping, request):
960880
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
961-
if pa.types.is_duration(pa_dtype):
962-
request.node.add_marker(
963-
pytest.mark.xfail(
964-
raises=pa.ArrowNotImplementedError,
965-
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
966-
)
967-
)
968-
elif pa.types.is_boolean(pa_dtype):
881+
if pa.types.is_boolean(pa_dtype):
969882
request.node.add_marker(
970883
pytest.mark.xfail(
971884
reason=f"{pa_dtype} only has 2 unique possible values",
972885
)
973886
)
974887
super().test_factorize(data_for_grouping)
975888

976-
def test_factorize_equivalence(self, data_for_grouping, request):
977-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
978-
if pa.types.is_duration(pa_dtype):
979-
request.node.add_marker(
980-
pytest.mark.xfail(
981-
raises=pa.ArrowNotImplementedError,
982-
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
983-
)
984-
)
985-
super().test_factorize_equivalence(data_for_grouping)
986-
987-
def test_factorize_empty(self, data, request):
988-
pa_dtype = data.dtype.pyarrow_dtype
989-
if pa.types.is_duration(pa_dtype):
990-
request.node.add_marker(
991-
pytest.mark.xfail(
992-
raises=pa.ArrowNotImplementedError,
993-
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
994-
)
995-
)
996-
super().test_factorize_empty(data)
997-
998889
@pytest.mark.xfail(
999890
reason="result dtype pyarrow[bool] better than expected dtype object"
1000891
)

0 commit comments

Comments
 (0)