11
11
Hashable ,
12
12
Literal ,
13
13
Sequence ,
14
- Union ,
15
14
cast ,
16
15
final ,
17
16
)
30
29
ArrayLike ,
31
30
DtypeObj ,
32
31
IndexLabel ,
33
- Scalar ,
34
32
TakeIndexer ,
35
33
npt ,
36
34
)
105
103
)
106
104
from pandas .core .arrays import (
107
105
BaseMaskedArray ,
108
- DatetimeArray ,
109
106
ExtensionArray ,
110
- TimedeltaArray ,
111
107
)
112
108
113
109
@@ -539,13 +535,24 @@ def factorize_array(
539
535
codes : ndarray[np.intp]
540
536
uniques : ndarray
541
537
"""
538
+ original = values
539
+ if values .dtype .kind in ["m" , "M" ]:
540
+ # _get_hashtable_algo will cast dt64/td64 to i8 via _ensure_data, so we
541
+ # need to do the same to na_value. We are assuming here that the passed
542
+ # na_value is an appropriately-typed NaT.
543
+ # e.g. test_where_datetimelike_categorical
544
+ na_value = iNaT
545
+
542
546
hash_klass , values = _get_hashtable_algo (values )
543
547
544
548
table = hash_klass (size_hint or len (values ))
545
549
uniques , codes = table .factorize (
546
550
values , na_sentinel = na_sentinel , na_value = na_value , mask = mask
547
551
)
548
552
553
+ # re-cast e.g. i8->dt64/td64, uint8->bool
554
+ uniques = _reconstruct_data (uniques , original .dtype , original )
555
+
549
556
codes = ensure_platform_int (codes )
550
557
return codes , uniques
551
558
@@ -720,33 +727,18 @@ def factorize(
720
727
isinstance (values , (ABCDatetimeArray , ABCTimedeltaArray ))
721
728
and values .freq is not None
722
729
):
730
+ # The presence of 'freq' means we can fast-path sorting and know there
731
+ # aren't NAs
723
732
codes , uniques = values .factorize (sort = sort )
724
- if isinstance (original , ABCIndex ):
725
- uniques = original ._shallow_copy (uniques , name = None )
726
- elif isinstance (original , ABCSeries ):
727
- from pandas import Index
728
-
729
- uniques = Index (uniques )
730
- return codes , uniques
733
+ return _re_wrap_factorize (original , uniques , codes )
731
734
732
735
if not isinstance (values .dtype , np .dtype ):
733
736
# i.e. ExtensionDtype
734
737
codes , uniques = values .factorize (na_sentinel = na_sentinel )
735
- dtype = original .dtype
736
738
else :
737
- dtype = values .dtype
738
- values = _ensure_data (values )
739
- na_value : Scalar | None
740
-
741
- if original .dtype .kind in ["m" , "M" ]:
742
- # Note: factorize_array will cast NaT bc it has a __int__
743
- # method, but will not cast the more-correct dtype.type("nat")
744
- na_value = iNaT
745
- else :
746
- na_value = None
747
-
739
+ values = np .asarray (values ) # convert DTA/TDA/MultiIndex
748
740
codes , uniques = factorize_array (
749
- values , na_sentinel = na_sentinel , size_hint = size_hint , na_value = na_value
741
+ values , na_sentinel = na_sentinel , size_hint = size_hint
750
742
)
751
743
752
744
if sort and len (uniques ) > 0 :
@@ -759,23 +751,20 @@ def factorize(
759
751
# na_value is set based on the dtype of uniques, and compat set to False is
760
752
# because we do not want na_value to be 0 for integers
761
753
na_value = na_value_for_dtype (uniques .dtype , compat = False )
762
- # Argument 2 to "append" has incompatible type "List[Union[str, float, Period,
763
- # Timestamp, Timedelta, Any]]"; expected "Union[_SupportsArray[dtype[Any]],
764
- # _NestedSequence[_SupportsArray[dtype[Any]]]
765
- # , bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int,
766
- # float, complex, str, bytes]]]" [arg-type]
767
- uniques = np .append (uniques , [na_value ]) # type: ignore[arg-type]
754
+ uniques = np .append (uniques , [na_value ])
768
755
codes = np .where (code_is_na , len (uniques ) - 1 , codes )
769
756
770
- uniques = _reconstruct_data (uniques , dtype , original )
757
+ uniques = _reconstruct_data (uniques , original .dtype , original )
758
+
759
+ return _re_wrap_factorize (original , uniques , codes )
760
+
771
761
772
- # return original tenor
762
+ def _re_wrap_factorize (original , uniques , codes : np .ndarray ):
763
+ """
764
+ Wrap factorize results in Series or Index depending on original type.
765
+ """
773
766
if isinstance (original , ABCIndex ):
774
- if original .dtype .kind in ["m" , "M" ] and isinstance (uniques , np .ndarray ):
775
- original ._data = cast (
776
- "Union[DatetimeArray, TimedeltaArray]" , original ._data
777
- )
778
- uniques = type (original ._data )._simple_new (uniques , dtype = original .dtype )
767
+ uniques = ensure_wrapped_if_datetimelike (uniques )
779
768
uniques = original ._shallow_copy (uniques , name = None )
780
769
elif isinstance (original , ABCSeries ):
781
770
from pandas import Index
0 commit comments