@@ -660,12 +660,11 @@ def factorize(
660
660
pa_type = self ._data .type
661
661
if pa .types .is_duration (pa_type ):
662
662
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
663
- arr = cast (ArrowExtensionArray , self .astype ("int64[pyarrow]" ))
664
- indices , uniques = arr .factorize (use_na_sentinel = use_na_sentinel )
665
- uniques = uniques .astype (self .dtype )
666
- return indices , uniques
663
+ data = self ._data .cast (pa .int64 ())
664
+ else :
665
+ data = self ._data
667
666
668
- encoded = self . _data .dictionary_encode (null_encoding = null_encoding )
667
+ encoded = data .dictionary_encode (null_encoding = null_encoding )
669
668
if encoded .length () == 0 :
670
669
indices = np .array ([], dtype = np .intp )
671
670
uniques = type (self )(pa .chunked_array ([], type = encoded .type .value_type ))
@@ -677,6 +676,9 @@ def factorize(
677
676
np .intp , copy = False
678
677
)
679
678
uniques = type (self )(encoded .chunk (0 ).dictionary )
679
+
680
+ if pa .types .is_duration (pa_type ):
681
+ uniques = cast (ArrowExtensionArray , uniques .astype (self .dtype ))
680
682
return indices , uniques
681
683
682
684
def reshape (self , * args , ** kwargs ):
@@ -861,13 +863,20 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
861
863
-------
862
864
ArrowExtensionArray
863
865
"""
864
- if pa .types .is_duration (self ._data .type ):
866
+ pa_type = self ._data .type
867
+
868
+ if pa .types .is_duration (pa_type ):
865
869
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
866
- arr = cast (ArrowExtensionArrayT , self .astype ("int64[pyarrow]" ))
867
- result = arr .unique ()
868
- return cast (ArrowExtensionArrayT , result .astype (self .dtype ))
870
+ data = self ._data .cast (pa .int64 ())
871
+ else :
872
+ data = self ._data
873
+
874
+ pa_result = pc .unique (data )
869
875
870
- return type (self )(pc .unique (self ._data ))
876
+ if pa .types .is_duration (pa_type ):
877
+ pa_result = pa_result .cast (pa_type )
878
+
879
+ return type (self )(pa_result )
871
880
872
881
def value_counts (self , dropna : bool = True ) -> Series :
873
882
"""
@@ -886,27 +895,30 @@ def value_counts(self, dropna: bool = True) -> Series:
886
895
--------
887
896
Series.value_counts
888
897
"""
889
- if pa .types .is_duration (self ._data .type ):
898
+ pa_type = self ._data .type
899
+ if pa .types .is_duration (pa_type ):
890
900
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
891
- arr = cast (ArrowExtensionArray , self .astype ("int64[pyarrow]" ))
892
- result = arr .value_counts ()
893
- result .index = result .index .astype (self .dtype )
894
- return result
901
+ data = self ._data .cast (pa .int64 ())
902
+ else :
903
+ data = self ._data
895
904
896
905
from pandas import (
897
906
Index ,
898
907
Series ,
899
908
)
900
909
901
- vc = self . _data .value_counts ()
910
+ vc = data .value_counts ()
902
911
903
912
values = vc .field (0 )
904
913
counts = vc .field (1 )
905
- if dropna and self . _data .null_count > 0 :
914
+ if dropna and data .null_count > 0 :
906
915
mask = values .is_valid ()
907
916
values = values .filter (mask )
908
917
counts = counts .filter (mask )
909
918
919
+ if pa .types .is_duration (pa_type ):
920
+ values = values .cast (pa_type )
921
+
910
922
# No missing values so we can adhere to the interface and return a numpy array.
911
923
counts = np .array (counts )
912
924
@@ -1214,12 +1226,29 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra
1214
1226
"""
1215
1227
if pa_version_under6p0 :
1216
1228
raise NotImplementedError ("mode only supported for pyarrow version >= 6.0" )
1217
- modes = pc .mode (self ._data , pc .count_distinct (self ._data ).as_py ())
1229
+
1230
+ pa_type = self ._data .type
1231
+ if pa .types .is_temporal (pa_type ):
1232
+ nbits = pa_type .bit_width
1233
+ if nbits == 32 :
1234
+ data = self ._data .cast (pa .int32 ())
1235
+ elif nbits == 64 :
1236
+ data = self ._data .cast (pa .int64 ())
1237
+ else :
1238
+ raise NotImplementedError (pa_type )
1239
+ else :
1240
+ data = self ._data
1241
+
1242
+ modes = pc .mode (data , pc .count_distinct (data ).as_py ())
1218
1243
values = modes .field (0 )
1219
1244
counts = modes .field (1 )
1220
1245
# counts sorted descending i.e counts[0] = max
1221
1246
mask = pc .equal (counts , counts [0 ])
1222
1247
most_common = values .filter (mask )
1248
+
1249
+ if pa .types .is_temporal (pa_type ):
1250
+ most_common = most_common .cast (pa_type )
1251
+
1223
1252
return type (self )(most_common )
1224
1253
1225
1254
def _maybe_convert_setitem_value (self , value ):
0 commit comments