Skip to content

Commit 9ea1031

Browse files
authored
ENH: implement _should_compare/_is_comparable_dtype for all Index subclasses (#38251)
1 parent 69021e3 commit 9ea1031

File tree

6 files changed

+58
-25
lines changed

6 files changed

+58
-25
lines changed

pandas/core/indexes/base.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -4906,16 +4906,31 @@ def get_indexer_non_unique(self, target):
49064906
# Treat boolean labels passed to a numeric index as not found. Without
49074907
# this fix False and True would be treated as 0 and 1 respectively.
49084908
# (GH #16877)
4909-
no_matches = -1 * np.ones(self.shape, dtype=np.intp)
4910-
return no_matches, no_matches
4909+
return self._get_indexer_non_comparable(target, method=None, unique=False)
49114910

49124911
pself, ptarget = self._maybe_promote(target)
49134912
if pself is not self or ptarget is not target:
49144913
return pself.get_indexer_non_unique(ptarget)
49154914

4916-
if not self._is_comparable_dtype(target.dtype):
4917-
no_matches = -1 * np.ones(self.shape, dtype=np.intp)
4918-
return no_matches, no_matches
4915+
if not self._should_compare(target):
4916+
return self._get_indexer_non_comparable(target, method=None, unique=False)
4917+
4918+
if not is_dtype_equal(self.dtype, target.dtype):
4919+
# TODO: if object, could use infer_dtype to pre-empt costly
4920+
# conversion if still non-comparable?
4921+
dtype = find_common_type([self.dtype, target.dtype])
4922+
if (
4923+
dtype.kind in ["i", "u"]
4924+
and is_categorical_dtype(target.dtype)
4925+
and target.hasnans
4926+
):
4927+
# FIXME: find_common_type incorrect with Categorical GH#38240
4928+
# FIXME: some cases where float64 cast can be lossy?
4929+
dtype = np.dtype(np.float64)
4930+
4931+
this = self.astype(dtype, copy=False)
4932+
that = target.astype(dtype, copy=False)
4933+
return this.get_indexer_non_unique(that)
49194934

49204935
if is_categorical_dtype(target.dtype):
49214936
tgt_values = np.asarray(target)
@@ -4968,7 +4983,7 @@ def _get_indexer_non_comparable(self, target: "Index", method, unique: bool = Tr
49684983
If doing an inequality check, i.e. method is not None.
49694984
"""
49704985
if method is not None:
4971-
other = _unpack_nested_dtype(target)
4986+
other = unpack_nested_dtype(target)
49724987
raise TypeError(f"Cannot compare dtypes {self.dtype} and {other.dtype}")
49734988

49744989
no_matches = -1 * np.ones(target.shape, dtype=np.intp)
@@ -5019,7 +5034,7 @@ def _should_compare(self, other: "Index") -> bool:
50195034
"""
50205035
Check if `self == other` can ever have non-False entries.
50215036
"""
5022-
other = _unpack_nested_dtype(other)
5037+
other = unpack_nested_dtype(other)
50235038
dtype = other.dtype
50245039
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
50255040

@@ -6172,7 +6187,7 @@ def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]:
61726187
return names
61736188

61746189

6175-
def _unpack_nested_dtype(other: Index) -> Index:
6190+
def unpack_nested_dtype(other: Index) -> Index:
61766191
"""
61776192
When checking if our dtype is comparable with another, we need
61786193
to unpack CategoricalDtype to look at its categories.dtype.

pandas/core/indexes/category.py

+3
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,9 @@ def _maybe_cast_slice_bound(self, label, side: str, kind):
554554

555555
# --------------------------------------------------------------------
556556

557+
def _is_comparable_dtype(self, dtype):
558+
return self.categories._is_comparable_dtype(dtype)
559+
557560
def take_nd(self, *args, **kwargs):
558561
"""Alias for `take`"""
559562
warnings.warn(

pandas/core/indexes/interval.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pandas._libs import lib
1212
from pandas._libs.interval import Interval, IntervalMixin, IntervalTree
1313
from pandas._libs.tslibs import BaseOffset, Timedelta, Timestamp, to_offset
14-
from pandas._typing import AnyArrayLike, Label
14+
from pandas._typing import AnyArrayLike, DtypeObj, Label
1515
from pandas.errors import InvalidIndexError
1616
from pandas.util._decorators import Appender, Substitution, cache_readonly
1717
from pandas.util._exceptions import rewrite_exception
@@ -38,6 +38,7 @@
3838
is_object_dtype,
3939
is_scalar,
4040
)
41+
from pandas.core.dtypes.dtypes import IntervalDtype
4142

4243
from pandas.core.algorithms import take_1d
4344
from pandas.core.arrays.interval import IntervalArray, _interval_shared_docs
@@ -50,6 +51,7 @@
5051
default_pprint,
5152
ensure_index,
5253
maybe_extract_name,
54+
unpack_nested_dtype,
5355
)
5456
from pandas.core.indexes.datetimes import DatetimeIndex, date_range
5557
from pandas.core.indexes.extension import ExtensionIndex, inherit_names
@@ -803,15 +805,27 @@ def _convert_list_indexer(self, keyarr):
803805

804806
return locs
805807

808+
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
809+
if not isinstance(dtype, IntervalDtype):
810+
return False
811+
common_subtype = find_common_type([self.dtype.subtype, dtype.subtype])
812+
return not is_object_dtype(common_subtype)
813+
814+
def _should_compare(self, other) -> bool:
815+
if not super()._should_compare(other):
816+
return False
817+
other = unpack_nested_dtype(other)
818+
return other.closed == self.closed
819+
820+
# TODO: use should_compare and get rid of _is_non_comparable_own_type
806821
def _is_non_comparable_own_type(self, other: "IntervalIndex") -> bool:
807822
# different closed or incompatible subtype -> no matches
808823

809824
# TODO: once closed is part of IntervalDtype, we can just define
810825
# is_comparable_dtype GH#19371
811826
if self.closed != other.closed:
812827
return True
813-
common_subtype = find_common_type([self.dtype.subtype, other.dtype.subtype])
814-
return is_object_dtype(common_subtype)
828+
return not self._is_comparable_dtype(other.dtype)
815829

816830
# --------------------------------------------------------------------
817831

pandas/core/indexes/multi.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pandas._libs import algos as libalgos, index as libindex, lib
2222
from pandas._libs.hashtable import duplicated_int64
23-
from pandas._typing import AnyArrayLike, Label, Scalar, Shape
23+
from pandas._typing import AnyArrayLike, DtypeObj, Label, Scalar, Shape
2424
from pandas.compat.numpy import function as nv
2525
from pandas.errors import InvalidIndexError, PerformanceWarning, UnsortedIndexError
2626
from pandas.util._decorators import Appender, cache_readonly, doc
@@ -3583,6 +3583,9 @@ def union(self, other, sort=None):
35833583
zip(*uniq_tuples), sortorder=0, names=result_names
35843584
)
35853585

3586+
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
3587+
return is_object_dtype(dtype)
3588+
35863589
def intersection(self, other, sort=False):
35873590
"""
35883591
Form the intersection of two MultiIndex objects.
@@ -3618,15 +3621,9 @@ def intersection(self, other, sort=False):
36183621
def _intersection(self, other, sort=False):
36193622
other, result_names = self._convert_can_do_setop(other)
36203623

3621-
if not is_object_dtype(other.dtype):
3624+
if not self._is_comparable_dtype(other.dtype):
36223625
# The intersection is empty
3623-
# TODO: we have no tests that get here
3624-
return MultiIndex(
3625-
levels=self.levels,
3626-
codes=[[]] * self.nlevels,
3627-
names=result_names,
3628-
verify_integrity=False,
3629-
)
3626+
return self[:0].rename(result_names)
36303627

36313628
lvals = self._values
36323629
rvals = other._values

pandas/core/indexes/numeric.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from pandas._libs import index as libindex, lib
7-
from pandas._typing import Dtype, Label
7+
from pandas._typing import Dtype, DtypeObj, Label
88
from pandas.util._decorators import doc
99

1010
from pandas.core.dtypes.cast import astype_nansafe
@@ -148,6 +148,10 @@ def _convert_tolerance(self, tolerance, target):
148148
)
149149
return tolerance
150150

151+
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
152+
# If we ever have BoolIndex or ComplexIndex, this may need to be tightened
153+
return is_numeric_dtype(dtype)
154+
151155
@classmethod
152156
def _assert_safe_casting(cls, data, subarr):
153157
"""

pandas/tests/indexes/test_base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1251,10 +1251,9 @@ def test_get_indexer_numeric_index_boolean_target(self, method, idx_class):
12511251
if method == "get_indexer":
12521252
tm.assert_numpy_array_equal(result, expected)
12531253
else:
1254-
expected = np.array([-1, -1, -1, -1], dtype=np.intp)
1255-
1254+
missing = np.arange(3, dtype=np.intp)
12561255
tm.assert_numpy_array_equal(result[0], expected)
1257-
tm.assert_numpy_array_equal(result[1], expected)
1256+
tm.assert_numpy_array_equal(result[1], missing)
12581257

12591258
def test_get_indexer_with_NA_values(
12601259
self, unique_nulls_fixture, unique_nulls_fixture2
@@ -2359,5 +2358,6 @@ def construct(dtype):
23592358

23602359
else:
23612360
no_matches = np.array([-1] * 6, dtype=np.intp)
2361+
missing = np.arange(6, dtype=np.intp)
23622362
tm.assert_numpy_array_equal(result[0], no_matches)
2363-
tm.assert_numpy_array_equal(result[1], no_matches)
2363+
tm.assert_numpy_array_equal(result[1], missing)

0 commit comments

Comments
 (0)