def _convert_arrays_and_get_rizer_klass(
    lk: ArrayLike, rk: ArrayLike
) -> tuple[type[libhashtable.Factorizer], ArrayLike, ArrayLike]:
    """
    Pick the factorizer class for a pair of join keys and align the keys to it.

    Parameters
    ----------
    lk : ArrayLike
        Left join key. Only its dtype decides whether the numeric path or
        the object fallback is taken.
    rk : ArrayLike
        Right join key.

    Returns
    -------
    tuple of (factorizer class, ArrayLike, ArrayLike)
        The factorizer class to instantiate, followed by the (possibly
        converted) left and right keys.

    Notes
    -----
    * When ``lk`` is not numeric, both keys go through ``ensure_object`` and
      ``libhashtable.ObjectFactorizer`` is returned.
    * When ``lk`` is numeric and ``is_dtype_equal(lk, rk)`` is false, both
      keys are converted to ``find_common_type`` of the two dtypes. The
      factorizer class is then looked up in ``_factorizers`` by
      ``lk.dtype.type``.
    """
    if not is_numeric_dtype(lk.dtype):
        # Object fallback; the left key is converted before the right one.
        return libhashtable.ObjectFactorizer, ensure_object(lk), ensure_object(rk)

    if not is_dtype_equal(lk, rk):
        common_dtype = find_common_type([lk.dtype, rk.dtype])
        if isinstance(common_dtype, ExtensionDtype):
            array_cls = common_dtype.construct_array_type()
            # Keys that are already ExtensionArrays are cast with ``astype``;
            # anything else is wrapped via ``_from_sequence`` with copy=False.
            lk = (
                lk.astype(common_dtype)
                if isinstance(lk, ExtensionArray)
                else array_cls._from_sequence(lk, dtype=common_dtype, copy=False)
            )
            rk = (
                rk.astype(common_dtype)
                if isinstance(rk, ExtensionArray)
                else array_cls._from_sequence(rk, dtype=common_dtype, copy=False)
            )
        else:
            lk = lk.astype(common_dtype)
            rk = rk.astype(common_dtype)

    # The lookup is the same for masked and non-masked keys. For masked
    # arrays mypy reports:
    # Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
    # expected type "Type[object]"
    rizer_klass: type[libhashtable.Factorizer]
    rizer_klass = _factorizers[lk.dtype.type]  # type: ignore[index]
    return rizer_klass, lk, rk