Skip to content

Commit ac3c010

Browse files
authored
TST: simplify pyarrow tests, make mode work with temporal dtypes (pandas-dev#50688)
* TST: simplify pyarrow tests, make mode work with temporal dtypes * mypy, min_version fixups * use pa cast * mypy fixup * use cast instead of astype
1 parent 5d11658 commit ac3c010

File tree

2 files changed

+173
-175
lines changed

2 files changed

+173
-175
lines changed

pandas/core/arrays/arrow/array.py

+47-18
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,11 @@ def factorize(
660660
pa_type = self._data.type
661661
if pa.types.is_duration(pa_type):
662662
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
663-
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
664-
indices, uniques = arr.factorize(use_na_sentinel=use_na_sentinel)
665-
uniques = uniques.astype(self.dtype)
666-
return indices, uniques
663+
data = self._data.cast(pa.int64())
664+
else:
665+
data = self._data
667666

668-
encoded = self._data.dictionary_encode(null_encoding=null_encoding)
667+
encoded = data.dictionary_encode(null_encoding=null_encoding)
669668
if encoded.length() == 0:
670669
indices = np.array([], dtype=np.intp)
671670
uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type))
@@ -677,6 +676,9 @@ def factorize(
677676
np.intp, copy=False
678677
)
679678
uniques = type(self)(encoded.chunk(0).dictionary)
679+
680+
if pa.types.is_duration(pa_type):
681+
uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype))
680682
return indices, uniques
681683

682684
def reshape(self, *args, **kwargs):
@@ -861,13 +863,20 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
861863
-------
862864
ArrowExtensionArray
863865
"""
864-
if pa.types.is_duration(self._data.type):
866+
pa_type = self._data.type
867+
868+
if pa.types.is_duration(pa_type):
865869
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
866-
arr = cast(ArrowExtensionArrayT, self.astype("int64[pyarrow]"))
867-
result = arr.unique()
868-
return cast(ArrowExtensionArrayT, result.astype(self.dtype))
870+
data = self._data.cast(pa.int64())
871+
else:
872+
data = self._data
873+
874+
pa_result = pc.unique(data)
869875

870-
return type(self)(pc.unique(self._data))
876+
if pa.types.is_duration(pa_type):
877+
pa_result = pa_result.cast(pa_type)
878+
879+
return type(self)(pa_result)
871880

872881
def value_counts(self, dropna: bool = True) -> Series:
873882
"""
@@ -886,27 +895,30 @@ def value_counts(self, dropna: bool = True) -> Series:
886895
--------
887896
Series.value_counts
888897
"""
889-
if pa.types.is_duration(self._data.type):
898+
pa_type = self._data.type
899+
if pa.types.is_duration(pa_type):
890900
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
891-
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
892-
result = arr.value_counts()
893-
result.index = result.index.astype(self.dtype)
894-
return result
901+
data = self._data.cast(pa.int64())
902+
else:
903+
data = self._data
895904

896905
from pandas import (
897906
Index,
898907
Series,
899908
)
900909

901-
vc = self._data.value_counts()
910+
vc = data.value_counts()
902911

903912
values = vc.field(0)
904913
counts = vc.field(1)
905-
if dropna and self._data.null_count > 0:
914+
if dropna and data.null_count > 0:
906915
mask = values.is_valid()
907916
values = values.filter(mask)
908917
counts = counts.filter(mask)
909918

919+
if pa.types.is_duration(pa_type):
920+
values = values.cast(pa_type)
921+
910922
# No missing values so we can adhere to the interface and return a numpy array.
911923
counts = np.array(counts)
912924

@@ -1214,12 +1226,29 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra
12141226
"""
12151227
if pa_version_under6p0:
12161228
raise NotImplementedError("mode only supported for pyarrow version >= 6.0")
1217-
modes = pc.mode(self._data, pc.count_distinct(self._data).as_py())
1229+
1230+
pa_type = self._data.type
1231+
if pa.types.is_temporal(pa_type):
1232+
nbits = pa_type.bit_width
1233+
if nbits == 32:
1234+
data = self._data.cast(pa.int32())
1235+
elif nbits == 64:
1236+
data = self._data.cast(pa.int64())
1237+
else:
1238+
raise NotImplementedError(pa_type)
1239+
else:
1240+
data = self._data
1241+
1242+
modes = pc.mode(data, pc.count_distinct(data).as_py())
12181243
values = modes.field(0)
12191244
counts = modes.field(1)
12201245
# counts sorted descending i.e counts[0] = max
12211246
mask = pc.equal(counts, counts[0])
12221247
most_common = values.filter(mask)
1248+
1249+
if pa.types.is_temporal(pa_type):
1250+
most_common = most_common.cast(pa_type)
1251+
12231252
return type(self)(most_common)
12241253

12251254
def _maybe_convert_setitem_value(self, value):

0 commit comments

Comments
 (0)