diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 3c5421ae433b6..dd4c7c1eedb49 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1171,3 +1171,15 @@ def __from_arrow__( results.append(iarr) return IntervalArray._concat_same_type(results) + + def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]: + # NB: this doesn't handle checking for closed match + if not all(isinstance(x, IntervalDtype) for x in dtypes): + return np.dtype(object) + + from pandas.core.dtypes.cast import find_common_type + + common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes]) + if common == object: + return np.dtype(object) + return IntervalDtype(common) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ee25a9d81a60f..c79dabe9eb1de 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -810,9 +810,11 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: return not is_object_dtype(common_subtype) def _should_compare(self, other) -> bool: - if not super()._should_compare(other): - return False other = unpack_nested_dtype(other) + if is_object_dtype(other.dtype): + return True + if not self._is_comparable_dtype(other.dtype): + return False return other.closed == self.closed # TODO: use should_compare and get rid of _is_non_comparable_own_type @@ -972,23 +974,6 @@ def _assert_can_do_setop(self, other): "and have compatible dtypes" ) - @Appender(Index.intersection.__doc__) - def intersection(self, other, sort=False) -> Index: - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - if not isinstance(other, IntervalIndex): - return self.astype(object).intersection(other) - - result = self._intersection(other, sort=sort) - return self._wrap_setop_result(other, result) - def _intersection(self, other, sort): """ intersection specialized to the case with matching dtypes. diff --git a/pandas/tests/dtypes/cast/test_find_common_type.py b/pandas/tests/dtypes/cast/test_find_common_type.py index 8dac92f469703..7b1aa12dc0cc4 100644 --- a/pandas/tests/dtypes/cast/test_find_common_type.py +++ b/pandas/tests/dtypes/cast/test_find_common_type.py @@ -2,7 +2,12 @@ import pytest from pandas.core.dtypes.cast import find_common_type -from pandas.core.dtypes.dtypes import CategoricalDtype, DatetimeTZDtype, PeriodDtype +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + IntervalDtype, + PeriodDtype, +) @pytest.mark.parametrize( @@ -120,3 +125,34 @@ def test_period_dtype_mismatch(dtype2): dtype = PeriodDtype(freq="D") assert find_common_type([dtype, dtype2]) == object assert find_common_type([dtype2, dtype]) == object + + +interval_dtypes = [ + IntervalDtype(np.int64), + IntervalDtype(np.float64), + IntervalDtype(np.uint64), + IntervalDtype(DatetimeTZDtype(unit="ns", tz="US/Eastern")), + IntervalDtype("M8[ns]"), + IntervalDtype("m8[ns]"), +] + + +@pytest.mark.parametrize("left", interval_dtypes) +@pytest.mark.parametrize("right", interval_dtypes) +def test_interval_dtype(left, right): + result = find_common_type([left, right]) + + if left is right: + assert result is left + + elif left.subtype.kind in ["i", "u", "f"]: + # i.e. numeric + if right.subtype.kind in ["i", "u", "f"]: + # both numeric -> common numeric subtype + expected = IntervalDtype(np.float64) + assert result == expected + else: + assert result == object + + else: + assert result == object diff --git a/pandas/tests/indexes/interval/test_setops.py b/pandas/tests/indexes/interval/test_setops.py index 0ef833bb93ded..278c27b302ef9 100644 --- a/pandas/tests/indexes/interval/test_setops.py +++ b/pandas/tests/indexes/interval/test_setops.py @@ -61,17 +61,6 @@ def test_intersection(self, closed, sort): tm.assert_index_equal(index.intersection(index, sort=sort), index) - # GH 19101: empty result, same dtype - other = monotonic_index(300, 314, closed=closed) - expected = empty_index(dtype="int64", closed=closed) - result = index.intersection(other, sort=sort) - tm.assert_index_equal(result, expected) - - # GH 19101: empty result, different dtypes - other = monotonic_index(300, 314, dtype="float64", closed=closed) - result = index.intersection(other, sort=sort) - tm.assert_index_equal(result, expected) - # GH 26225: nested intervals index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)]) other = IntervalIndex.from_tuples([(1, 2), (1, 3)]) @@ -100,6 +89,25 @@ def test_intersection(self, closed, sort): result = index.intersection(other) tm.assert_index_equal(result, expected) + def test_intersection_empty_result(self, closed, sort): + index = monotonic_index(0, 11, closed=closed) + + # GH 19101: empty result, same dtype + other = monotonic_index(300, 314, closed=closed) + expected = empty_index(dtype="int64", closed=closed) + result = index.intersection(other, sort=sort) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different numeric dtypes -> common dtype is float64 + other = monotonic_index(300, 314, dtype="float64", closed=closed) + result = index.intersection(other, sort=sort) + expected = other[:0] + tm.assert_index_equal(result, expected) + + other = monotonic_index(300, 314, dtype="uint64", closed=closed) + result = index.intersection(other, sort=sort) + tm.assert_index_equal(result, expected) + def test_difference(self, closed, sort): index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed) result = index.difference(index[:1], sort=sort)