78
78
)
79
79
80
80
from pandas import (
81
+ ArrowDtype ,
81
82
Categorical ,
82
83
Index ,
83
84
MultiIndex ,
84
85
Series ,
85
86
)
86
87
import pandas .core .algorithms as algos
87
88
from pandas .core .arrays import (
89
+ ArrowExtensionArray ,
88
90
BaseMaskedArray ,
89
91
ExtensionArray ,
90
92
)
@@ -2372,7 +2374,11 @@ def _factorize_keys(
2372
2374
rk = ensure_int64 (rk .codes )
2373
2375
2374
2376
elif isinstance (lk , ExtensionArray ) and is_dtype_equal (lk .dtype , rk .dtype ):
2375
- if not isinstance (lk , BaseMaskedArray ):
2377
+ if not isinstance (lk , BaseMaskedArray ) and not (
2378
+ # exclude arrow dtypes that would get cast to object
2379
+ isinstance (lk .dtype , ArrowDtype )
2380
+ and is_numeric_dtype (lk .dtype .numpy_dtype )
2381
+ ):
2376
2382
lk , _ = lk ._values_for_factorize ()
2377
2383
2378
2384
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
@@ -2387,6 +2393,16 @@ def _factorize_keys(
2387
2393
assert isinstance (rk , BaseMaskedArray )
2388
2394
llab = rizer .factorize (lk ._data , mask = lk ._mask )
2389
2395
rlab = rizer .factorize (rk ._data , mask = rk ._mask )
2396
+ elif isinstance (lk , ArrowExtensionArray ):
2397
+ assert isinstance (rk , ArrowExtensionArray )
2398
+ # we can only get here with numeric dtypes
2399
+ # TODO: Remove when we have a Factorizer for Arrow
2400
+ llab = rizer .factorize (
2401
+ lk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = lk .isna ()
2402
+ )
2403
+ rlab = rizer .factorize (
2404
+ rk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = rk .isna ()
2405
+ )
2390
2406
else :
2391
2407
# Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
2392
2408
# "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
@@ -2445,6 +2461,8 @@ def _convert_arrays_and_get_rizer_klass(
2445
2461
# Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
2446
2462
# expected type "Type[object]"
2447
2463
klass = _factorizers [lk .dtype .type ] # type: ignore[index]
2464
+ elif isinstance (lk .dtype , ArrowDtype ):
2465
+ klass = _factorizers [lk .dtype .numpy_dtype .type ]
2448
2466
else :
2449
2467
klass = _factorizers [lk .dtype .type ]
2450
2468
0 commit comments