Skip to content

Commit 4985b89

Browse files
authored
REF: simplify factorize, fix factorize_array return type (#46214)
1 parent: 5efb570; commit: 4985b89

File tree

6 files changed, with 34 additions and 57 deletions.

pandas/core/algorithms.py

+26 -37
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
Hashable,
1212
Literal,
1313
Sequence,
14-
Union,
1514
cast,
1615
final,
1716
)
@@ -30,7 +29,6 @@
3029
ArrayLike,
3130
DtypeObj,
3231
IndexLabel,
33-
Scalar,
3432
TakeIndexer,
3533
npt,
3634
)
@@ -105,9 +103,7 @@
105103
)
106104
from pandas.core.arrays import (
107105
BaseMaskedArray,
108-
DatetimeArray,
109106
ExtensionArray,
110-
TimedeltaArray,
111107
)
112108

113109

@@ -539,13 +535,24 @@ def factorize_array(
539535
codes : ndarray[np.intp]
540536
uniques : ndarray
541537
"""
538+
original = values
539+
if values.dtype.kind in ["m", "M"]:
540+
# _get_hashtable_algo will cast dt64/td64 to i8 via _ensure_data, so we
541+
# need to do the same to na_value. We are assuming here that the passed
542+
# na_value is an appropriately-typed NaT.
543+
# e.g. test_where_datetimelike_categorical
544+
na_value = iNaT
545+
542546
hash_klass, values = _get_hashtable_algo(values)
543547

544548
table = hash_klass(size_hint or len(values))
545549
uniques, codes = table.factorize(
546550
values, na_sentinel=na_sentinel, na_value=na_value, mask=mask
547551
)
548552

553+
# re-cast e.g. i8->dt64/td64, uint8->bool
554+
uniques = _reconstruct_data(uniques, original.dtype, original)
555+
549556
codes = ensure_platform_int(codes)
550557
return codes, uniques
551558

@@ -720,33 +727,18 @@ def factorize(
720727
isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray))
721728
and values.freq is not None
722729
):
730+
# The presence of 'freq' means we can fast-path sorting and know there
731+
# aren't NAs
723732
codes, uniques = values.factorize(sort=sort)
724-
if isinstance(original, ABCIndex):
725-
uniques = original._shallow_copy(uniques, name=None)
726-
elif isinstance(original, ABCSeries):
727-
from pandas import Index
728-
729-
uniques = Index(uniques)
730-
return codes, uniques
733+
return _re_wrap_factorize(original, uniques, codes)
731734

732735
if not isinstance(values.dtype, np.dtype):
733736
# i.e. ExtensionDtype
734737
codes, uniques = values.factorize(na_sentinel=na_sentinel)
735-
dtype = original.dtype
736738
else:
737-
dtype = values.dtype
738-
values = _ensure_data(values)
739-
na_value: Scalar | None
740-
741-
if original.dtype.kind in ["m", "M"]:
742-
# Note: factorize_array will cast NaT bc it has a __int__
743-
# method, but will not cast the more-correct dtype.type("nat")
744-
na_value = iNaT
745-
else:
746-
na_value = None
747-
739+
values = np.asarray(values) # convert DTA/TDA/MultiIndex
748740
codes, uniques = factorize_array(
749-
values, na_sentinel=na_sentinel, size_hint=size_hint, na_value=na_value
741+
values, na_sentinel=na_sentinel, size_hint=size_hint
750742
)
751743

752744
if sort and len(uniques) > 0:
@@ -759,23 +751,20 @@ def factorize(
759751
# na_value is set based on the dtype of uniques, and compat set to False is
760752
# because we do not want na_value to be 0 for integers
761753
na_value = na_value_for_dtype(uniques.dtype, compat=False)
762-
# Argument 2 to "append" has incompatible type "List[Union[str, float, Period,
763-
# Timestamp, Timedelta, Any]]"; expected "Union[_SupportsArray[dtype[Any]],
764-
# _NestedSequence[_SupportsArray[dtype[Any]]]
765-
# , bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int,
766-
# float, complex, str, bytes]]]" [arg-type]
767-
uniques = np.append(uniques, [na_value]) # type: ignore[arg-type]
754+
uniques = np.append(uniques, [na_value])
768755
codes = np.where(code_is_na, len(uniques) - 1, codes)
769756

770-
uniques = _reconstruct_data(uniques, dtype, original)
757+
uniques = _reconstruct_data(uniques, original.dtype, original)
758+
759+
return _re_wrap_factorize(original, uniques, codes)
760+
771761

772-
# return original tenor
762+
def _re_wrap_factorize(original, uniques, codes: np.ndarray):
763+
"""
764+
Wrap factorize results in Series or Index depending on original type.
765+
"""
773766
if isinstance(original, ABCIndex):
774-
if original.dtype.kind in ["m", "M"] and isinstance(uniques, np.ndarray):
775-
original._data = cast(
776-
"Union[DatetimeArray, TimedeltaArray]", original._data
777-
)
778-
uniques = type(original._data)._simple_new(uniques, dtype=original.dtype)
767+
uniques = ensure_wrapped_if_datetimelike(uniques)
779768
uniques = original._shallow_copy(uniques, name=None)
780769
elif isinstance(original, ABCSeries):
781770
from pandas import Index

pandas/core/arrays/_mixins.py

+4
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,10 @@ def equals(self, other) -> bool:
182182
return False
183183
return bool(array_equivalent(self._ndarray, other._ndarray))
184184

185+
def _from_factorized(cls, values, original):
186+
assert values.dtype == original._ndarray.dtype
187+
return original._from_backing_data(values)
188+
185189
def _values_for_argsort(self) -> np.ndarray:
186190
return self._ndarray
187191

pandas/core/arrays/categorical.py

-6
Original file line number | Diff line number | Diff line change
@@ -2305,12 +2305,6 @@ def unique(self):
23052305
def _values_for_factorize(self):
23062306
return self._ndarray, -1
23072307

2308-
@classmethod
2309-
def _from_factorized(cls, uniques, original):
2310-
# ensure we have the same itemsize for codes
2311-
codes = coerce_indexer_dtype(uniques, original.dtype.categories)
2312-
return original._from_backing_data(codes)
2313-
23142308
def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
23152309
# make sure we have correct itemsize for resulting codes
23162310
res_values = coerce_indexer_dtype(res_values, self.dtype.categories)

pandas/core/arrays/datetimelike.py

+1 -8
Original file line number | Diff line number | Diff line change
@@ -550,14 +550,7 @@ def copy(self: DatetimeLikeArrayT, order="C") -> DatetimeLikeArrayT:
550550
return new_obj
551551

552552
def _values_for_factorize(self):
553-
# int64 instead of int ensures we have a "view" method
554-
return self._ndarray, np.int64(iNaT)
555-
556-
@classmethod
557-
def _from_factorized(
558-
cls: type[DatetimeLikeArrayT], values, original: DatetimeLikeArrayT
559-
) -> DatetimeLikeArrayT:
560-
return cls(values, dtype=original.dtype)
553+
return self._ndarray, self._internal_fill_value
561554

562555
# ------------------------------------------------------------------
563556
# Validation Methods

pandas/core/arrays/masked.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -874,8 +874,9 @@ def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
874874

875875
codes, uniques = factorize_array(arr, na_sentinel=na_sentinel, mask=mask)
876876

877-
# the hashtables don't handle all different types of bits
878-
uniques = uniques.astype(self.dtype.numpy_dtype, copy=False)
877+
# check that factorize_array correctly preserves dtype.
878+
assert uniques.dtype == self.dtype.numpy_dtype, (uniques.dtype, self.dtype)
879+
879880
uniques_ea = type(self)(uniques, np.zeros(len(uniques), dtype=bool))
880881
return codes, uniques_ea
881882

pandas/core/arrays/numpy_.py

-4
Original file line number | Diff line number | Diff line change
@@ -109,10 +109,6 @@ def _from_sequence(
109109
result = result.copy()
110110
return cls(result)
111111

112-
@classmethod
113-
def _from_factorized(cls, values, original) -> PandasArray:
114-
return original._from_backing_data(values)
115-
116112
def _from_backing_data(self, arr: np.ndarray) -> PandasArray:
117113
return type(self)(arr)
118114

0 commit comments

Comments
 (0)