Skip to content

Commit 53fa89f

Browse files
jbrockmendelluckyvs1
authored andcommitted
REF: share IntervalIndex.intersection with Index.intersection (pandas-dev#38373)
1 parent 52d5e74 commit 53fa89f

File tree

4 files changed

+72
-31
lines changed

4 files changed

+72
-31
lines changed

pandas/core/dtypes/dtypes.py

+12
Original file line numberDiff line numberDiff line change
@@ -1171,3 +1171,15 @@ def __from_arrow__(
11711171
results.append(iarr)
11721172

11731173
return IntervalArray._concat_same_type(results)
1174+
1175+
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
1176+
# NB: this doesn't handle checking for closed match
1177+
if not all(isinstance(x, IntervalDtype) for x in dtypes):
1178+
return np.dtype(object)
1179+
1180+
from pandas.core.dtypes.cast import find_common_type
1181+
1182+
common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes])
1183+
if common == object:
1184+
return np.dtype(object)
1185+
return IntervalDtype(common)

pandas/core/indexes/interval.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -789,9 +789,11 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
789789
return not is_object_dtype(common_subtype)
790790

791791
def _should_compare(self, other) -> bool:
792-
if not super()._should_compare(other):
793-
return False
794792
other = unpack_nested_dtype(other)
793+
if is_object_dtype(other.dtype):
794+
return True
795+
if not self._is_comparable_dtype(other.dtype):
796+
return False
795797
return other.closed == self.closed
796798

797799
# TODO: use should_compare and get rid of _is_non_comparable_own_type
@@ -951,23 +953,6 @@ def _assert_can_do_setop(self, other):
951953
"and have compatible dtypes"
952954
)
953955

954-
@Appender(Index.intersection.__doc__)
955-
def intersection(self, other, sort=False) -> Index:
956-
self._validate_sort_keyword(sort)
957-
self._assert_can_do_setop(other)
958-
other, _ = self._convert_can_do_setop(other)
959-
960-
if self.equals(other):
961-
if self.has_duplicates:
962-
return self.unique()._get_reconciled_name_object(other)
963-
return self._get_reconciled_name_object(other)
964-
965-
if not isinstance(other, IntervalIndex):
966-
return self.astype(object).intersection(other)
967-
968-
result = self._intersection(other, sort=sort)
969-
return self._wrap_setop_result(other, result)
970-
971956
def _intersection(self, other, sort):
972957
"""
973958
intersection specialized to the case with matching dtypes.

pandas/tests/dtypes/cast/test_find_common_type.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import pytest
33

44
from pandas.core.dtypes.cast import find_common_type
5-
from pandas.core.dtypes.dtypes import CategoricalDtype, DatetimeTZDtype, PeriodDtype
5+
from pandas.core.dtypes.dtypes import (
6+
CategoricalDtype,
7+
DatetimeTZDtype,
8+
IntervalDtype,
9+
PeriodDtype,
10+
)
611

712

813
@pytest.mark.parametrize(
@@ -120,3 +125,34 @@ def test_period_dtype_mismatch(dtype2):
120125
dtype = PeriodDtype(freq="D")
121126
assert find_common_type([dtype, dtype2]) == object
122127
assert find_common_type([dtype2, dtype]) == object
128+
129+
130+
interval_dtypes = [
131+
IntervalDtype(np.int64),
132+
IntervalDtype(np.float64),
133+
IntervalDtype(np.uint64),
134+
IntervalDtype(DatetimeTZDtype(unit="ns", tz="US/Eastern")),
135+
IntervalDtype("M8[ns]"),
136+
IntervalDtype("m8[ns]"),
137+
]
138+
139+
140+
@pytest.mark.parametrize("left", interval_dtypes)
141+
@pytest.mark.parametrize("right", interval_dtypes)
142+
def test_interval_dtype(left, right):
143+
result = find_common_type([left, right])
144+
145+
if left is right:
146+
assert result is left
147+
148+
elif left.subtype.kind in ["i", "u", "f"]:
149+
# i.e. numeric
150+
if right.subtype.kind in ["i", "u", "f"]:
151+
# both numeric -> common numeric subtype
152+
expected = IntervalDtype(np.float64)
153+
assert result == expected
154+
else:
155+
assert result == object
156+
157+
else:
158+
assert result == object

pandas/tests/indexes/interval/test_setops.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,6 @@ def test_intersection(self, closed, sort):
6161

6262
tm.assert_index_equal(index.intersection(index, sort=sort), index)
6363

64-
# GH 19101: empty result, same dtype
65-
other = monotonic_index(300, 314, closed=closed)
66-
expected = empty_index(dtype="int64", closed=closed)
67-
result = index.intersection(other, sort=sort)
68-
tm.assert_index_equal(result, expected)
69-
70-
# GH 19101: empty result, different dtypes
71-
other = monotonic_index(300, 314, dtype="float64", closed=closed)
72-
result = index.intersection(other, sort=sort)
73-
tm.assert_index_equal(result, expected)
74-
7564
# GH 26225: nested intervals
7665
index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
7766
other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
@@ -100,6 +89,25 @@ def test_intersection(self, closed, sort):
10089
result = index.intersection(other)
10190
tm.assert_index_equal(result, expected)
10291

92+
def test_intersection_empty_result(self, closed, sort):
93+
index = monotonic_index(0, 11, closed=closed)
94+
95+
# GH 19101: empty result, same dtype
96+
other = monotonic_index(300, 314, closed=closed)
97+
expected = empty_index(dtype="int64", closed=closed)
98+
result = index.intersection(other, sort=sort)
99+
tm.assert_index_equal(result, expected)
100+
101+
# GH 19101: empty result, different numeric dtypes -> common dtype is float64
102+
other = monotonic_index(300, 314, dtype="float64", closed=closed)
103+
result = index.intersection(other, sort=sort)
104+
expected = other[:0]
105+
tm.assert_index_equal(result, expected)
106+
107+
other = monotonic_index(300, 314, dtype="uint64", closed=closed)
108+
result = index.intersection(other, sort=sort)
109+
tm.assert_index_equal(result, expected)
110+
103111
def test_difference(self, closed, sort):
104112
index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
105113
result = index.difference(index[:1], sort=sort)

0 commit comments

Comments
 (0)