diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 65bfd8289fe3d..3b5bc5dbd6c83 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -212,7 +212,7 @@ Strings Interval ^^^^^^^^ - +- Bug in :meth:`IntervalIndex.intersection` always returning object-dtype when intersecting with :class:`CategoricalIndex` (:issue:`38653`) - - @@ -236,6 +236,7 @@ MultiIndex - Bug in :meth:`DataFrame.drop` raising ``TypeError`` when :class:`MultiIndex` is non-unique and no level is provided (:issue:`36293`) - Bug in :meth:`MultiIndex.equals` incorrectly returning ``True`` when :class:`MultiIndex` containing ``NaN`` even when they are differntly ordered (:issue:`38439`) +- Bug in :meth:`MultiIndex.intersection` always returning empty when intersecting with :class:`CategoricalIndex` (:issue:`38653`) I/O ^^^ diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 75f3b511bc57d..4cfba314c719c 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1173,7 +1173,7 @@ def __from_arrow__( def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]: # NB: this doesn't handle checking for closed match if not all(isinstance(x, IntervalDtype) for x in dtypes): - return np.dtype(object) + return None from pandas.core.dtypes.cast import find_common_type diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e18b83a22202e..8d48a6277d412 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2814,6 +2814,8 @@ def intersection(self, other, sort=False): elif not self._should_compare(other): # We can infer that the intersection is empty. + if isinstance(self, ABCMultiIndex): + return self[:0].rename(result_name) return Index([], name=result_name) elif not is_dtype_equal(self.dtype, other.dtype): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 6a558c8c3210f..94c055e264e71 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -654,61 +654,16 @@ def difference(self, other, sort=None): new_idx = super().difference(other, sort=sort)._with_freq(None) return new_idx - def intersection(self, other, sort=False): - """ - Specialized intersection for DatetimeIndex/TimedeltaIndex. - - May be much faster than Index.intersection - - Parameters - ---------- - other : Same type as self or array-like - sort : False or None, default False - Sort the resulting index if possible. - - .. versionadded:: 0.24.0 - - .. versionchanged:: 0.24.1 - - Changed the default to ``False`` to match the behaviour - from before 0.24.0. - - .. versionchanged:: 0.25.0 - - The `sort` keyword is added - - Returns - ------- - y : Index or same type as self - """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, result_name = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - elif not self._should_compare(other): - # We can infer that the intersection is empty. - return Index([], name=result_name) - - return self._intersection(other, sort=sort) - def _intersection(self, other: Index, sort=False) -> Index: """ intersection specialized to the case with matching dtypes. """ + other = cast("DatetimeTimedeltaMixin", other) if len(self) == 0: return self.copy()._get_reconciled_name_object(other) if len(other) == 0: return other.copy()._get_reconciled_name_object(self) - if not isinstance(other, type(self)): - result = Index.intersection(self, other, sort=sort) - return result - elif not self._can_fast_intersect(other): result = Index._intersection(self, other, sort=sort) # We need to invalidate the freq because Index._intersection diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 5312dfe84cfd8..c7be66b596246 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3449,8 +3449,8 @@ def equals(self, other: object) -> bool: if not isinstance(other, MultiIndex): # d-level MultiIndex can equal d-tuple Index - if not is_object_dtype(other.dtype): - # other cannot contain tuples, so cannot match self + if not self._should_compare(other): + # object Index or Categorical[object] may contain tuples return False return array_equivalent(self._values, other._values) @@ -3588,13 +3588,15 @@ def union(self, other, sort=None): def _union(self, other, sort): other, result_names = self._convert_can_do_setop(other) - if not is_object_dtype(other.dtype): + if not self._should_compare(other): raise NotImplementedError( "Can only union MultiIndex with MultiIndex or Index of tuples, " "try mi.to_flat_index().union(other) instead." ) - uniq_tuples = lib.fast_unique_multiple([self._values, other._values], sort=sort) + # We could get here with CategoricalIndex other + rvals = other._values.astype(object, copy=False) + uniq_tuples = lib.fast_unique_multiple([self._values, rvals], sort=sort) return MultiIndex.from_arrays( zip(*uniq_tuples), sortorder=0, names=result_names @@ -3631,47 +3633,11 @@ def _maybe_match_names(self, other): names.append(None) return names - def intersection(self, other, sort=False): - """ - Form the intersection of two MultiIndex objects. - - Parameters - ---------- - other : MultiIndex or array / Index of tuples - sort : False or None, default False - Sort the resulting MultiIndex if possible - - .. versionadded:: 0.24.0 - - .. versionchanged:: 0.24.1 - - Changed the default from ``True`` to ``False``, to match - behaviour from before 0.24.0 - - Returns - ------- - Index - """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - return self._intersection(other, sort=sort) - def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) - if not self._is_comparable_dtype(other.dtype): - # The intersection is empty - return self[:0].rename(result_names) - lvals = self._values - rvals = other._values + rvals = other._values.astype(object, copy=False) uniq_tuples = None # flag whether _inner_indexer was successful if self.is_monotonic and other.is_monotonic: diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index d5822d919ae64..4d48bc0d51912 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -632,26 +632,6 @@ def _setop(self, other, sort, opname: str): result = type(self)._simple_new(parr, name=res_name) return result - def intersection(self, other, sort=False): - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, result_name = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - elif not self._should_compare(other): - # We can infer that the intersection is empty. - return Index([], name=result_name) - - elif not is_dtype_equal(self.dtype, other.dtype): - # i.e. object dtype - return super().intersection(other, sort=sort) - - return self._intersection(other, sort=sort) - def _intersection(self, other, sort=False): return self._setop(other, sort, opname="intersection") diff --git a/pandas/tests/dtypes/cast/test_find_common_type.py b/pandas/tests/dtypes/cast/test_find_common_type.py index 7b1aa12dc0cc4..6043deec573f8 100644 --- a/pandas/tests/dtypes/cast/test_find_common_type.py +++ b/pandas/tests/dtypes/cast/test_find_common_type.py @@ -9,6 +9,8 @@ PeriodDtype, ) +from pandas import Categorical, Index + @pytest.mark.parametrize( "source_dtypes,expected_common_dtype", @@ -156,3 +158,13 @@ def test_interval_dtype(left, right): else: assert result == object + + +@pytest.mark.parametrize("dtype", interval_dtypes) +def test_interval_dtype_with_categorical(dtype): + obj = Index([], dtype=dtype) + + cat = Categorical([], categories=obj) + + result = find_common_type([dtype, cat.dtype]) + assert result == dtype diff --git a/pandas/tests/indexes/multi/test_equivalence.py b/pandas/tests/indexes/multi/test_equivalence.py index 52acb3d0797c6..c44f7622c04dd 100644 --- a/pandas/tests/indexes/multi/test_equivalence.py +++ b/pandas/tests/indexes/multi/test_equivalence.py @@ -10,6 +10,8 @@ def test_equals(idx): assert idx.equals(idx) assert idx.equals(idx.copy()) assert idx.equals(idx.astype(object)) + assert idx.equals(idx.to_flat_index()) + assert idx.equals(idx.to_flat_index().astype("category")) assert not idx.equals(list(idx)) assert not idx.equals(np.array(idx)) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index a26eb793afe7e..9a7ff78bae3db 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -294,6 +294,20 @@ def test_intersection(idx, sort): # assert result.equals(tuples) +@pytest.mark.parametrize("method", ["intersection", "union"]) +def test_setop_with_categorical(idx, sort, method): + other = idx.to_flat_index().astype("category") + res_names = [None] * idx.nlevels + + result = getattr(idx, method)(other, sort=sort) + expected = getattr(idx, method)(idx, sort=sort).rename(res_names) + tm.assert_index_equal(result, expected) + + result = getattr(idx, method)(other[:5], sort=sort) + expected = getattr(idx, method)(idx[:5], sort=sort).rename(res_names) + tm.assert_index_equal(result, expected) + + def test_intersection_non_object(idx, sort): other = Index(range(3), name="foo") diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 6f949960ce30b..538e937703de6 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -446,3 +446,20 @@ def test_intersection_difference_match_empty(self, index, sort): inter = index.intersection(index[:0]) diff = index.difference(index, sort=sort) tm.assert_index_equal(inter, diff, exact=True) + + +@pytest.mark.parametrize("method", ["intersection", "union"]) +def test_setop_with_categorical(index, sort, method): + if isinstance(index, MultiIndex): + # tested separately in tests.indexes.multi.test_setops + return + + other = index.astype("category") + + result = getattr(index, method)(other, sort=sort) + expected = getattr(index, method)(index, sort=sort) + tm.assert_index_equal(result, expected) + + result = getattr(index, method)(other[:5], sort=sort) + expected = getattr(index, method)(index[:5], sort=sort) + tm.assert_index_equal(result, expected)