@@ -3395,6 +3395,7 @@ def join(self, other, how="left", level=None, return_indexers=False, sort=False)
3395
3395
-------
3396
3396
join_index, (left_indexer, right_indexer)
3397
3397
"""
3398
+ other = ensure_index (other )
3398
3399
self_is_mi = isinstance (self , ABCMultiIndex )
3399
3400
other_is_mi = isinstance (other , ABCMultiIndex )
3400
3401
@@ -3414,8 +3415,6 @@ def join(self, other, how="left", level=None, return_indexers=False, sort=False)
3414
3415
other , level , how = how , return_indexers = return_indexers
3415
3416
)
3416
3417
3417
- other = ensure_index (other )
3418
-
3419
3418
if len (other ) == 0 and how in ("left" , "outer" ):
3420
3419
join_index = self ._shallow_copy ()
3421
3420
if return_indexers :
@@ -3577,16 +3576,26 @@ def _join_multi(self, other, how, return_indexers=True):
3577
3576
def _join_non_unique (self , other , how = "left" , return_indexers = False ):
3578
3577
from pandas .core .reshape .merge import _get_join_indexers
3579
3578
3579
+ # We only get here if dtypes match
3580
+ assert self .dtype == other .dtype
3581
+
3582
+ if is_extension_array_dtype (self .dtype ):
3583
+ lvalues = self ._data ._values_for_argsort ()
3584
+ rvalues = other ._data ._values_for_argsort ()
3585
+ else :
3586
+ lvalues = self ._values
3587
+ rvalues = other ._values
3588
+
3580
3589
left_idx , right_idx = _get_join_indexers (
3581
- [self . _ndarray_values ], [other . _ndarray_values ], how = how , sort = True
3590
+ [lvalues ], [rvalues ], how = how , sort = True
3582
3591
)
3583
3592
3584
3593
left_idx = ensure_platform_int (left_idx )
3585
3594
right_idx = ensure_platform_int (right_idx )
3586
3595
3587
- join_index = np .asarray (self . _ndarray_values .take (left_idx ))
3596
+ join_index = np .asarray (lvalues .take (left_idx ))
3588
3597
mask = left_idx == - 1
3589
- np .putmask (join_index , mask , other . _ndarray_values .take (right_idx ))
3598
+ np .putmask (join_index , mask , rvalues .take (right_idx ))
3590
3599
3591
3600
join_index = self ._wrap_joined_index (join_index , other )
3592
3601
@@ -3737,15 +3746,22 @@ def _get_leaf_sorter(labels):
3737
3746
return join_index
3738
3747
3739
3748
def _join_monotonic (self , other , how = "left" , return_indexers = False ):
3749
+ # We only get here with matching dtypes
3750
+ assert other .dtype == self .dtype
3751
+
3740
3752
if self .equals (other ):
3741
3753
ret_index = other if how == "right" else self
3742
3754
if return_indexers :
3743
3755
return ret_index , None , None
3744
3756
else :
3745
3757
return ret_index
3746
3758
3747
- sv = self ._ndarray_values
3748
- ov = other ._ndarray_values
3759
+ if is_extension_array_dtype (self .dtype ):
3760
+ sv = self ._data ._values_for_argsort ()
3761
+ ov = other ._data ._values_for_argsort ()
3762
+ else :
3763
+ sv = self ._values
3764
+ ov = other ._values
3749
3765
3750
3766
if self .is_unique and other .is_unique :
3751
3767
# We can perform much better than the general case
0 commit comments