Skip to content

Commit 19dcd0f

Browse files
jbrockmendel authored and luckyvs1 committed
REF: standardize get_indexer/get_indexer_non_unique code (pandas-dev#38648)
1 parent d94b482 commit 19dcd0f

File tree

3 files changed

+31
-54
lines changed

3 files changed

+31
-54
lines changed

pandas/core/groupby/groupby.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1233,7 +1233,7 @@ def reset_identity(values):
12331233
# so we resort to this
12341234
# GH 14776, 30667
12351235
if ax.has_duplicates and not result.axes[self.axis].equals(ax):
1236-
indexer, _ = result.index.get_indexer_non_unique(ax.values)
1236+
indexer, _ = result.index.get_indexer_non_unique(ax._values)
12371237
indexer = algorithms.unique1d(indexer)
12381238
result = result.take(indexer, axis=self.axis)
12391239
else:

pandas/core/indexes/base.py

+12 -14
Original file line number | Diff line number | Diff line change
@@ -2191,7 +2191,7 @@ def _nan_idxs(self):
21912191
if self._can_hold_na:
21922192
return self._isnan.nonzero()[0]
21932193
else:
2194-
return np.array([], dtype=np.int64)
2194+
return np.array([], dtype=np.intp)
21952195

21962196
@cache_readonly
21972197
def hasnans(self) -> bool:
@@ -2728,12 +2728,12 @@ def _union(self, other, sort):
27282728
# find indexes of things in "other" that are not in "self"
27292729
if self.is_unique:
27302730
indexer = self.get_indexer(other)
2731-
indexer = (indexer == -1).nonzero()[0]
2731+
missing = (indexer == -1).nonzero()[0]
27322732
else:
2733-
indexer = algos.unique1d(self.get_indexer_non_unique(other)[1])
2733+
missing = algos.unique1d(self.get_indexer_non_unique(other)[1])
27342734

2735-
if len(indexer) > 0:
2736-
other_diff = algos.take_nd(rvals, indexer, allow_fill=False)
2735+
if len(missing) > 0:
2736+
other_diff = algos.take_nd(rvals, missing, allow_fill=False)
27372737
result = concat_compat((lvals, other_diff))
27382738

27392739
else:
@@ -2838,13 +2838,14 @@ def _intersection(self, other, sort=False):
28382838
return algos.unique1d(result)
28392839

28402840
try:
2841-
indexer = Index(rvals).get_indexer(lvals)
2842-
indexer = indexer.take((indexer != -1).nonzero()[0])
2841+
indexer = other.get_indexer(lvals)
28432842
except (InvalidIndexError, IncompatibleFrequency):
28442843
# InvalidIndexError raised by get_indexer if non-unique
28452844
# IncompatibleFrequency raised by PeriodIndex.get_indexer
2846-
indexer = algos.unique1d(Index(rvals).get_indexer_non_unique(lvals)[0])
2847-
indexer = indexer[indexer != -1]
2845+
indexer, _ = other.get_indexer_non_unique(lvals)
2846+
2847+
mask = indexer != -1
2848+
indexer = indexer.take(mask.nonzero()[0])
28482849

28492850
result = other.take(indexer).unique()._values
28502851

@@ -3526,7 +3527,7 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
35263527
"cannot reindex a non-unique index "
35273528
"with a method or limit"
35283529
)
3529-
indexer, missing = self.get_indexer_non_unique(target)
3530+
indexer, _ = self.get_indexer_non_unique(target)
35303531

35313532
if preserve_names and target.nlevels == 1 and target.name != self.name:
35323533
target = target.copy()
@@ -4919,10 +4920,7 @@ def get_indexer_non_unique(self, target):
49194920
that = target.astype(dtype, copy=False)
49204921
return this.get_indexer_non_unique(that)
49214922

4922-
if is_categorical_dtype(target.dtype):
4923-
tgt_values = np.asarray(target)
4924-
else:
4925-
tgt_values = target._get_engine_target()
4923+
tgt_values = target._get_engine_target()
49264924

49274925
indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
49284926
return ensure_platform_int(indexer), missing

pandas/core/indexes/interval.py

+18 -39
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from pandas._libs import lib
1212
from pandas._libs.interval import Interval, IntervalMixin, IntervalTree
1313
from pandas._libs.tslibs import BaseOffset, Timedelta, Timestamp, to_offset
14-
from pandas._typing import AnyArrayLike, DtypeObj, Label
14+
from pandas._typing import DtypeObj, Label
1515
from pandas.errors import InvalidIndexError
1616
from pandas.util._decorators import Appender, cache_readonly
1717
from pandas.util._exceptions import rewrite_exception
@@ -652,9 +652,8 @@ def _get_indexer(
652652
if self.equals(target):
653653
return np.arange(len(self), dtype="intp")
654654

655-
if self._is_non_comparable_own_type(target):
656-
# different closed or incompatible subtype -> no matches
657-
return np.repeat(np.intp(-1), len(target))
655+
if not self._should_compare(target):
656+
return self._get_indexer_non_comparable(target, method, unique=True)
658657

659658
# non-overlapping -> at most one match per interval in target
660659
# want exact matches -> need both left/right to match, so defer to
@@ -678,32 +677,22 @@ def _get_indexer(
678677
return ensure_platform_int(indexer)
679678

680679
@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
681-
def get_indexer_non_unique(
682-
self, target: AnyArrayLike
683-
) -> Tuple[np.ndarray, np.ndarray]:
684-
target_as_index = ensure_index(target)
685-
686-
# check that target_as_index IntervalIndex is compatible
687-
if isinstance(target_as_index, IntervalIndex):
688-
689-
if self._is_non_comparable_own_type(target_as_index):
690-
# different closed or incompatible subtype -> no matches
691-
return (
692-
np.repeat(-1, len(target_as_index)),
693-
np.arange(len(target_as_index)),
694-
)
695-
696-
if is_object_dtype(target_as_index) or isinstance(
697-
target_as_index, IntervalIndex
698-
):
699-
# target_as_index might contain intervals: defer elementwise to get_loc
700-
return self._get_indexer_pointwise(target_as_index)
680+
def get_indexer_non_unique(self, target: Index) -> Tuple[np.ndarray, np.ndarray]:
681+
target = ensure_index(target)
682+
683+
if isinstance(target, IntervalIndex) and not self._should_compare(target):
684+
# different closed or incompatible subtype -> no matches
685+
return self._get_indexer_non_comparable(target, None, unique=False)
686+
687+
elif is_object_dtype(target.dtype) or isinstance(target, IntervalIndex):
688+
# target might contain intervals: defer elementwise to get_loc
689+
return self._get_indexer_pointwise(target)
701690

702691
else:
703-
target_as_index = self._maybe_convert_i8(target_as_index)
704-
indexer, missing = self._engine.get_indexer_non_unique(
705-
target_as_index.values
706-
)
692+
# Note: this case behaves differently from other Index subclasses
693+
# because IntervalIndex does partial-int indexing
694+
target = self._maybe_convert_i8(target)
695+
indexer, missing = self._engine.get_indexer_non_unique(target.values)
707696

708697
return ensure_platform_int(indexer), ensure_platform_int(missing)
709698

@@ -789,16 +778,6 @@ def _should_compare(self, other) -> bool:
789778
return False
790779
return other.closed == self.closed
791780

792-
# TODO: use should_compare and get rid of _is_non_comparable_own_type
793-
def _is_non_comparable_own_type(self, other: "IntervalIndex") -> bool:
794-
# different closed or incompatible subtype -> no matches
795-
796-
# TODO: once closed is part of IntervalDtype, we can just define
797-
# is_comparable_dtype GH#19371
798-
if self.closed != other.closed:
799-
return True
800-
return not self._is_comparable_dtype(other.dtype)
801-
802781
# --------------------------------------------------------------------
803782

804783
@cache_readonly
@@ -938,7 +917,7 @@ def _format_space(self) -> str:
938917
def _assert_can_do_setop(self, other):
939918
super()._assert_can_do_setop(other)
940919

941-
if isinstance(other, IntervalIndex) and self._is_non_comparable_own_type(other):
920+
if isinstance(other, IntervalIndex) and not self._should_compare(other):
942921
# GH#19016: ensure set op will not return a prohibited dtype
943922
raise TypeError(
944923
"can only do set operations between two IntervalIndex "

0 commit comments

Comments
 (0)