Skip to content

Commit 309cf3a

Browse files
authored
REF: standardize get_indexer/get_indexer_non_unique (#39343)
1 parent 9ed521e commit 309cf3a

File tree

4 files changed

+65
-15
lines changed

4 files changed

+65
-15
lines changed

pandas/core/indexes/base.py

+19-14
Original file line number | Diff line number | Diff line change
@@ -3289,11 +3289,10 @@ def get_indexer(
32893289
if not self._index_as_unique:
32903290
raise InvalidIndexError(self._requires_unique_msg)
32913291

3292-
# Treat boolean labels passed to a numeric index as not found. Without
3293-
# this fix False and True would be treated as 0 and 1 respectively.
3294-
# (GH #16877)
3295-
if target.is_boolean() and self.is_numeric():
3296-
return ensure_platform_int(np.repeat(-1, target.size))
3292+
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
3293+
# IntervalIndex gets special treatment because numeric scalars can be
3294+
# matched to Interval scalars
3295+
return self._get_indexer_non_comparable(target, method=method, unique=True)
32973296

32983297
pself, ptarget = self._maybe_promote(target)
32993298
if pself is not self or ptarget is not target:
@@ -3310,8 +3309,9 @@ def _get_indexer(
33103309
tolerance = self._convert_tolerance(tolerance, target)
33113310

33123311
if not is_dtype_equal(self.dtype, target.dtype):
3313-
this = self.astype(object)
3314-
target = target.astype(object)
3312+
dtype = find_common_type([self.dtype, target.dtype])
3313+
this = self.astype(dtype, copy=False)
3314+
target = target.astype(dtype, copy=False)
33153315
return this.get_indexer(
33163316
target, method=method, limit=limit, tolerance=tolerance
33173317
)
@@ -5060,19 +5060,15 @@ def set_value(self, arr, key, value):
50605060
def get_indexer_non_unique(self, target):
50615061
target = ensure_index(target)
50625062

5063-
if target.is_boolean() and self.is_numeric():
5064-
# Treat boolean labels passed to a numeric index as not found. Without
5065-
# this fix False and True would be treated as 0 and 1 respectively.
5066-
# (GH #16877)
5063+
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
5064+
# IntervalIndex gets special treatment because numeric scalars can be
5065+
# matched to Interval scalars
50675066
return self._get_indexer_non_comparable(target, method=None, unique=False)
50685067

50695068
pself, ptarget = self._maybe_promote(target)
50705069
if pself is not self or ptarget is not target:
50715070
return pself.get_indexer_non_unique(ptarget)
50725071

5073-
if not self._should_compare(target):
5074-
return self._get_indexer_non_comparable(target, method=None, unique=False)
5075-
50765072
if not is_dtype_equal(self.dtype, target.dtype):
50775073
# TODO: if object, could use infer_dtype to preempt costly
50785074
# conversion if still non-comparable?
@@ -5193,6 +5189,15 @@ def _should_compare(self, other: Index) -> bool:
51935189
"""
51945190
Check if `self == other` can ever have non-False entries.
51955191
"""
5192+
5193+
if (other.is_boolean() and self.is_numeric()) or (
5194+
self.is_boolean() and other.is_numeric()
5195+
):
5196+
# GH#16877 Treat boolean labels passed to a numeric index as not
5197+
# found. Without this fix False and True would be treated as 0 and 1
5198+
# respectively.
5199+
return False
5200+
51965201
other = unpack_nested_dtype(other)
51975202
dtype = other.dtype
51985203
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)

pandas/tests/indexes/numeric/test_indexing.py

+31
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,37 @@ def test_get_loc_missing_nan(self):
8282

8383

8484
class TestGetIndexer:
85+
@pytest.mark.parametrize("method", ["pad", "backfill", "nearest"])
86+
def test_get_indexer_with_method_numeric_vs_bool(self, method):
87+
left = Index([1, 2, 3])
88+
right = Index([True, False])
89+
90+
with pytest.raises(TypeError, match="Cannot compare"):
91+
left.get_indexer(right, method=method)
92+
93+
with pytest.raises(TypeError, match="Cannot compare"):
94+
right.get_indexer(left, method=method)
95+
96+
def test_get_indexer_numeric_vs_bool(self):
97+
left = Index([1, 2, 3])
98+
right = Index([True, False])
99+
100+
res = left.get_indexer(right)
101+
expected = -1 * np.ones(len(right), dtype=np.intp)
102+
tm.assert_numpy_array_equal(res, expected)
103+
104+
res = right.get_indexer(left)
105+
expected = -1 * np.ones(len(left), dtype=np.intp)
106+
tm.assert_numpy_array_equal(res, expected)
107+
108+
res = left.get_indexer_non_unique(right)[0]
109+
expected = -1 * np.ones(len(right), dtype=np.intp)
110+
tm.assert_numpy_array_equal(res, expected)
111+
112+
res = right.get_indexer_non_unique(left)[0]
113+
expected = -1 * np.ones(len(left), dtype=np.intp)
114+
tm.assert_numpy_array_equal(res, expected)
115+
85116
def test_get_indexer_float64(self):
86117
idx = Float64Index([0.0, 1.0, 2.0])
87118
tm.assert_numpy_array_equal(

pandas/tests/indexes/test_base.py

+11
Original file line number | Diff line number | Diff line change
@@ -556,6 +556,17 @@ def test_asof(self, index):
556556
d = index[0].to_pydatetime()
557557
assert isinstance(index.asof(d), Timestamp)
558558

559+
def test_asof_numeric_vs_bool_raises(self):
560+
left = Index([1, 2, 3])
561+
right = Index([True, False])
562+
563+
msg = "'<' not supported between instances"
564+
with pytest.raises(TypeError, match=msg):
565+
left.asof(right)
566+
567+
with pytest.raises(TypeError, match=msg):
568+
right.asof(left)
569+
559570
def test_asof_datetime_partial(self):
560571
index = date_range("2010-01-01", periods=2, freq="m")
561572
expected = Timestamp("2010-02-28")

pandas/tests/series/methods/test_reindex.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,10 @@ def test_reindex_datetimeindexes_tz_naive_and_aware():
293293
idx = date_range("20131101", tz="America/Chicago", periods=7)
294294
newidx = date_range("20131103", periods=10, freq="H")
295295
s = Series(range(7), index=idx)
296-
msg = "Cannot compare tz-naive and tz-aware timestamps"
296+
msg = (
297+
r"Cannot compare dtypes datetime64\[ns, America/Chicago\] "
298+
r"and datetime64\[ns\]"
299+
)
297300
with pytest.raises(TypeError, match=msg):
298301
s.reindex(newidx, method="ffill")
299302

0 commit comments

Comments
 (0)