80
80
)
81
81
82
82
from pandas import (
83
+ ArrowDtype ,
83
84
Categorical ,
84
85
Index ,
85
86
MultiIndex ,
86
87
Series ,
87
88
)
88
89
import pandas .core .algorithms as algos
89
90
from pandas .core .arrays import (
91
+ ArrowExtensionArray ,
90
92
BaseMaskedArray ,
91
93
ExtensionArray ,
92
94
)
@@ -2377,7 +2379,11 @@ def _factorize_keys(
2377
2379
rk = ensure_int64 (rk .codes )
2378
2380
2379
2381
elif isinstance (lk , ExtensionArray ) and is_dtype_equal (lk .dtype , rk .dtype ):
2380
- if not isinstance (lk , BaseMaskedArray ):
2382
+ if not isinstance (lk , BaseMaskedArray ) and not (
2383
+ # exclude arrow dtypes that would get cast to object
2384
+ isinstance (lk .dtype , ArrowDtype )
2385
+ and is_numeric_dtype (lk .dtype .numpy_dtype )
2386
+ ):
2381
2387
lk , _ = lk ._values_for_factorize ()
2382
2388
2383
2389
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
@@ -2392,6 +2398,16 @@ def _factorize_keys(
2392
2398
assert isinstance (rk , BaseMaskedArray )
2393
2399
llab = rizer .factorize (lk ._data , mask = lk ._mask )
2394
2400
rlab = rizer .factorize (rk ._data , mask = rk ._mask )
2401
+ elif isinstance (lk , ArrowExtensionArray ):
2402
+ assert isinstance (rk , ArrowExtensionArray )
2403
+ # we can only get here with numeric dtypes
2404
+ # TODO: Remove when we have a Factorizer for Arrow
2405
+ llab = rizer .factorize (
2406
+ lk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = lk .isna ()
2407
+ )
2408
+ rlab = rizer .factorize (
2409
+ rk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = rk .isna ()
2410
+ )
2395
2411
else :
2396
2412
# Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
2397
2413
# "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
@@ -2450,6 +2466,8 @@ def _convert_arrays_and_get_rizer_klass(
2450
2466
# Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
2451
2467
# expected type "Type[object]"
2452
2468
klass = _factorizers [lk .dtype .type ] # type: ignore[index]
2469
+ elif isinstance (lk .dtype , ArrowDtype ):
2470
+ klass = _factorizers [lk .dtype .numpy_dtype .type ]
2453
2471
else :
2454
2472
klass = _factorizers [lk .dtype .type ]
2455
2473
0 commit comments