Skip to content

Commit 044df8c

Browse files
authored
REF: implement _should_compare (#38105)
1 parent b2f6b70 commit 044df8c

File tree

3 files changed

+122
-6
lines changed

3 files changed

+122
-6
lines changed

pandas/core/indexes/base.py

+66
Original file line numberDiff line numberDiff line change
@@ -4941,6 +4941,43 @@ def get_indexer_for(self, target, **kwargs):
49414941
indexer, _ = self.get_indexer_non_unique(target)
49424942
return indexer
49434943

4944+
def _get_indexer_non_comparable(self, target: "Index", method, unique: bool = True):
4945+
"""
4946+
Called from get_indexer or get_indexer_non_unique when the target
4947+
is of a non-comparable dtype.
4948+
4949+
For get_indexer lookups with method=None, get_indexer is an _equality_
4950+
check, so non-comparable dtypes mean we will always have no matches.
4951+
4952+
For get_indexer lookups with a method, get_indexer is an _inequality_
4953+
check, so non-comparable dtypes mean we will always raise TypeError.
4954+
4955+
Parameters
4956+
----------
4957+
target : Index
4958+
method : str or None
4959+
unique : bool, default True
4960+
* True if called from get_indexer.
4961+
* False if called from get_indexer_non_unique.
4962+
4963+
Raises
4964+
------
4965+
TypeError
4966+
If doing an inequality check, i.e. method is not None.
4967+
"""
4968+
if method is not None:
4969+
other = _unpack_nested_dtype(target)
4970+
raise TypeError(f"Cannot compare dtypes {self.dtype} and {other.dtype}")
4971+
4972+
no_matches = -1 * np.ones(target.shape, dtype=np.intp)
4973+
if unique:
4974+
# This is for get_indexer
4975+
return no_matches
4976+
else:
4977+
# This is for get_indexer_non_unique
4978+
missing = np.arange(len(target), dtype=np.intp)
4979+
return no_matches, missing
4980+
49444981
@property
49454982
def _index_as_unique(self):
49464983
"""
@@ -4976,6 +5013,14 @@ def _maybe_promote(self, other: "Index"):
49765013

49775014
return self, other
49785015

5016+
def _should_compare(self, other: "Index") -> bool:
5017+
"""
5018+
Check if `self == other` can ever have non-False entries.
5019+
"""
5020+
other = _unpack_nested_dtype(other)
5021+
dtype = other.dtype
5022+
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
5023+
49795024
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
49805025
"""
49815026
Can we compare values of the given dtype to our own?
@@ -6123,3 +6168,24 @@ def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]:
61236168
name_sets = [{*ns} for ns in zip_longest(*name_tups)]
61246169
names = tuple(ns.pop() if len(ns) == 1 else None for ns in name_sets)
61256170
return names
6171+
6172+
6173+
def _unpack_nested_dtype(other: Index) -> Index:
6174+
"""
6175+
When checking if our dtype is comparable with another, we need
6176+
to unpack CategoricalDtype to look at its categories.dtype.
6177+
6178+
Parameters
6179+
----------
6180+
other : Index
6181+
6182+
Returns
6183+
-------
6184+
Index
6185+
"""
6186+
dtype = other.dtype
6187+
if is_categorical_dtype(dtype):
6188+
# If there is ever a SparseIndex, this could get dispatched
6189+
# here too.
6190+
return dtype.categories
6191+
return other

pandas/core/indexes/period.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,10 @@ def join(self, other, how="left", level=None, return_indexers=False, sort=False)
452452
def get_indexer(self, target, method=None, limit=None, tolerance=None):
453453
target = ensure_index(target)
454454

455-
if isinstance(target, PeriodIndex):
456-
if not self._is_comparable_dtype(target.dtype):
457-
# i.e. target.freq != self.freq
458-
# No matches
459-
no_matches = -1 * np.ones(self.shape, dtype=np.intp)
460-
return no_matches
455+
if not self._should_compare(target):
456+
return self._get_indexer_non_comparable(target, method, unique=True)
461457

458+
if isinstance(target, PeriodIndex):
462459
target = target._get_engine_target() # i.e. target.asi8
463460
self_index = self._int64index
464461
else:

pandas/tests/indexes/period/test_indexing.py

+53
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@
2121
)
2222
import pandas._testing as tm
2323

24+
dti4 = date_range("2016-01-01", periods=4)
25+
dti = dti4[:-1]
26+
rng = pd.Index(range(3))
27+
28+
29+
@pytest.fixture(
30+
params=[
31+
dti,
32+
dti.tz_localize("UTC"),
33+
dti.to_period("W"),
34+
dti - dti[0],
35+
rng,
36+
pd.Index([1, 2, 3]),
37+
pd.Index([2.0, 3.0, 4.0]),
38+
pd.Index([4, 5, 6], dtype="u8"),
39+
pd.IntervalIndex.from_breaks(dti4),
40+
]
41+
)
42+
def non_comparable_idx(request):
43+
# All have length 3
44+
return request.param
45+
2446

2547
class TestGetItem:
2648
def test_ellipsis(self):
@@ -438,6 +460,37 @@ def test_get_indexer_mismatched_dtype(self):
438460
result = pi.get_indexer_non_unique(pi2)[0]
439461
tm.assert_numpy_array_equal(result, expected)
440462

463+
def test_get_indexer_mismatched_dtype_different_length(self, non_comparable_idx):
464+
# without method we arent checking inequalities, so get all-missing
465+
# but do not raise
466+
dti = date_range("2016-01-01", periods=3)
467+
pi = dti.to_period("D")
468+
469+
other = non_comparable_idx
470+
471+
res = pi[:-1].get_indexer(other)
472+
expected = -np.ones(other.shape, dtype=np.intp)
473+
tm.assert_numpy_array_equal(res, expected)
474+
475+
@pytest.mark.parametrize("method", ["pad", "backfill", "nearest"])
476+
def test_get_indexer_mismatched_dtype_with_method(self, non_comparable_idx, method):
477+
dti = date_range("2016-01-01", periods=3)
478+
pi = dti.to_period("D")
479+
480+
other = non_comparable_idx
481+
482+
msg = re.escape(f"Cannot compare dtypes {pi.dtype} and {other.dtype}")
483+
with pytest.raises(TypeError, match=msg):
484+
pi.get_indexer(other, method=method)
485+
486+
for dtype in ["object", "category"]:
487+
other2 = other.astype(dtype)
488+
if dtype == "object" and isinstance(other, PeriodIndex):
489+
continue
490+
# For object dtype we are liable to get a different exception message
491+
with pytest.raises(TypeError):
492+
pi.get_indexer(other2, method=method)
493+
441494
def test_get_indexer_non_unique(self):
442495
# GH 17717
443496
p1 = Period("2017-09-02")

0 commit comments

Comments
 (0)