Skip to content

Commit 05dca24

Browse files
jbrockmendel authored and luckyvs1 committed
BUG: MultiIndex, IntervalIndex intersection with Categorical (pandas-dev#38653)
1 parent 7e1799b commit 05dca24

File tree

10 files changed

+58
-109
lines changed

10 files changed

+58
-109
lines changed

doc/source/whatsnew/v1.3.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ Strings
214214

215215
Interval
216216
^^^^^^^^
217-
217+
- Bug in :meth:`IntervalIndex.intersection` always returning object-dtype when intersecting with :class:`CategoricalIndex` (:issue:`38653`)
218218
-
219219
-
220220

@@ -238,6 +238,7 @@ MultiIndex
238238

239239
- Bug in :meth:`DataFrame.drop` raising ``TypeError`` when :class:`MultiIndex` is non-unique and no level is provided (:issue:`36293`)
240240
- Bug in :meth:`MultiIndex.equals` incorrectly returning ``True`` when the :class:`MultiIndex` contains ``NaN`` values, even when they are differently ordered (:issue:`38439`)
241+
- Bug in :meth:`MultiIndex.intersection` always returning empty when intersecting with :class:`CategoricalIndex` (:issue:`38653`)
241242

242243
I/O
243244
^^^

pandas/core/dtypes/dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ def __from_arrow__(
11731173
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
11741174
# NB: this doesn't handle checking for closed match
11751175
if not all(isinstance(x, IntervalDtype) for x in dtypes):
1176-
return np.dtype(object)
1176+
return None
11771177

11781178
from pandas.core.dtypes.cast import find_common_type
11791179

pandas/core/indexes/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -2814,6 +2814,8 @@ def intersection(self, other, sort=False):
28142814

28152815
elif not self._should_compare(other):
28162816
# We can infer that the intersection is empty.
2817+
if isinstance(self, ABCMultiIndex):
2818+
return self[:0].rename(result_name)
28172819
return Index([], name=result_name)
28182820

28192821
elif not is_dtype_equal(self.dtype, other.dtype):

pandas/core/indexes/datetimelike.py

+1-46
Original file line numberDiff line numberDiff line change
@@ -654,61 +654,16 @@ def difference(self, other, sort=None):
654654
new_idx = super().difference(other, sort=sort)._with_freq(None)
655655
return new_idx
656656

657-
def intersection(self, other, sort=False):
658-
"""
659-
Specialized intersection for DatetimeIndex/TimedeltaIndex.
660-
661-
May be much faster than Index.intersection
662-
663-
Parameters
664-
----------
665-
other : Same type as self or array-like
666-
sort : False or None, default False
667-
Sort the resulting index if possible.
668-
669-
.. versionadded:: 0.24.0
670-
671-
.. versionchanged:: 0.24.1
672-
673-
Changed the default to ``False`` to match the behaviour
674-
from before 0.24.0.
675-
676-
.. versionchanged:: 0.25.0
677-
678-
The `sort` keyword is added
679-
680-
Returns
681-
-------
682-
y : Index or same type as self
683-
"""
684-
self._validate_sort_keyword(sort)
685-
self._assert_can_do_setop(other)
686-
other, result_name = self._convert_can_do_setop(other)
687-
688-
if self.equals(other):
689-
if self.has_duplicates:
690-
return self.unique()._get_reconciled_name_object(other)
691-
return self._get_reconciled_name_object(other)
692-
693-
elif not self._should_compare(other):
694-
# We can infer that the intersection is empty.
695-
return Index([], name=result_name)
696-
697-
return self._intersection(other, sort=sort)
698-
699657
def _intersection(self, other: Index, sort=False) -> Index:
700658
"""
701659
intersection specialized to the case with matching dtypes.
702660
"""
661+
other = cast("DatetimeTimedeltaMixin", other)
703662
if len(self) == 0:
704663
return self.copy()._get_reconciled_name_object(other)
705664
if len(other) == 0:
706665
return other.copy()._get_reconciled_name_object(self)
707666

708-
if not isinstance(other, type(self)):
709-
result = Index.intersection(self, other, sort=sort)
710-
return result
711-
712667
elif not self._can_fast_intersect(other):
713668
result = Index._intersection(self, other, sort=sort)
714669
# We need to invalidate the freq because Index._intersection

pandas/core/indexes/multi.py

+7-41
Original file line numberDiff line numberDiff line change
@@ -3449,8 +3449,8 @@ def equals(self, other: object) -> bool:
34493449

34503450
if not isinstance(other, MultiIndex):
34513451
# d-level MultiIndex can equal d-tuple Index
3452-
if not is_object_dtype(other.dtype):
3453-
# other cannot contain tuples, so cannot match self
3452+
if not self._should_compare(other):
3453+
# object Index or Categorical[object] may contain tuples
34543454
return False
34553455
return array_equivalent(self._values, other._values)
34563456

@@ -3588,13 +3588,15 @@ def union(self, other, sort=None):
35883588
def _union(self, other, sort):
35893589
other, result_names = self._convert_can_do_setop(other)
35903590

3591-
if not is_object_dtype(other.dtype):
3591+
if not self._should_compare(other):
35923592
raise NotImplementedError(
35933593
"Can only union MultiIndex with MultiIndex or Index of tuples, "
35943594
"try mi.to_flat_index().union(other) instead."
35953595
)
35963596

3597-
uniq_tuples = lib.fast_unique_multiple([self._values, other._values], sort=sort)
3597+
# We could get here with CategoricalIndex other
3598+
rvals = other._values.astype(object, copy=False)
3599+
uniq_tuples = lib.fast_unique_multiple([self._values, rvals], sort=sort)
35983600

35993601
return MultiIndex.from_arrays(
36003602
zip(*uniq_tuples), sortorder=0, names=result_names
@@ -3631,47 +3633,11 @@ def _maybe_match_names(self, other):
36313633
names.append(None)
36323634
return names
36333635

3634-
def intersection(self, other, sort=False):
3635-
"""
3636-
Form the intersection of two MultiIndex objects.
3637-
3638-
Parameters
3639-
----------
3640-
other : MultiIndex or array / Index of tuples
3641-
sort : False or None, default False
3642-
Sort the resulting MultiIndex if possible
3643-
3644-
.. versionadded:: 0.24.0
3645-
3646-
.. versionchanged:: 0.24.1
3647-
3648-
Changed the default from ``True`` to ``False``, to match
3649-
behaviour from before 0.24.0
3650-
3651-
Returns
3652-
-------
3653-
Index
3654-
"""
3655-
self._validate_sort_keyword(sort)
3656-
self._assert_can_do_setop(other)
3657-
other, _ = self._convert_can_do_setop(other)
3658-
3659-
if self.equals(other):
3660-
if self.has_duplicates:
3661-
return self.unique()._get_reconciled_name_object(other)
3662-
return self._get_reconciled_name_object(other)
3663-
3664-
return self._intersection(other, sort=sort)
3665-
36663636
def _intersection(self, other, sort=False):
36673637
other, result_names = self._convert_can_do_setop(other)
36683638

3669-
if not self._is_comparable_dtype(other.dtype):
3670-
# The intersection is empty
3671-
return self[:0].rename(result_names)
3672-
36733639
lvals = self._values
3674-
rvals = other._values
3640+
rvals = other._values.astype(object, copy=False)
36753641

36763642
uniq_tuples = None # flag whether _inner_indexer was successful
36773643
if self.is_monotonic and other.is_monotonic:

pandas/core/indexes/period.py

-20
Original file line numberDiff line numberDiff line change
@@ -632,26 +632,6 @@ def _setop(self, other, sort, opname: str):
632632
result = type(self)._simple_new(parr, name=res_name)
633633
return result
634634

635-
def intersection(self, other, sort=False):
636-
self._validate_sort_keyword(sort)
637-
self._assert_can_do_setop(other)
638-
other, result_name = self._convert_can_do_setop(other)
639-
640-
if self.equals(other):
641-
if self.has_duplicates:
642-
return self.unique()._get_reconciled_name_object(other)
643-
return self._get_reconciled_name_object(other)
644-
645-
elif not self._should_compare(other):
646-
# We can infer that the intersection is empty.
647-
return Index([], name=result_name)
648-
649-
elif not is_dtype_equal(self.dtype, other.dtype):
650-
# i.e. object dtype
651-
return super().intersection(other, sort=sort)
652-
653-
return self._intersection(other, sort=sort)
654-
655635
def _intersection(self, other, sort=False):
656636
return self._setop(other, sort, opname="intersection")
657637

pandas/tests/dtypes/cast/test_find_common_type.py

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
PeriodDtype,
1010
)
1111

12+
from pandas import Categorical, Index
13+
1214

1315
@pytest.mark.parametrize(
1416
"source_dtypes,expected_common_dtype",
@@ -156,3 +158,13 @@ def test_interval_dtype(left, right):
156158

157159
else:
158160
assert result == object
161+
162+
163+
@pytest.mark.parametrize("dtype", interval_dtypes)
164+
def test_interval_dtype_with_categorical(dtype):
165+
obj = Index([], dtype=dtype)
166+
167+
cat = Categorical([], categories=obj)
168+
169+
result = find_common_type([dtype, cat.dtype])
170+
assert result == dtype

pandas/tests/indexes/multi/test_equivalence.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ def test_equals(idx):
1010
assert idx.equals(idx)
1111
assert idx.equals(idx.copy())
1212
assert idx.equals(idx.astype(object))
13+
assert idx.equals(idx.to_flat_index())
14+
assert idx.equals(idx.to_flat_index().astype("category"))
1315

1416
assert not idx.equals(list(idx))
1517
assert not idx.equals(np.array(idx))

pandas/tests/indexes/multi/test_setops.py

+14
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,20 @@ def test_intersection(idx, sort):
294294
# assert result.equals(tuples)
295295

296296

297+
@pytest.mark.parametrize("method", ["intersection", "union"])
298+
def test_setop_with_categorical(idx, sort, method):
299+
other = idx.to_flat_index().astype("category")
300+
res_names = [None] * idx.nlevels
301+
302+
result = getattr(idx, method)(other, sort=sort)
303+
expected = getattr(idx, method)(idx, sort=sort).rename(res_names)
304+
tm.assert_index_equal(result, expected)
305+
306+
result = getattr(idx, method)(other[:5], sort=sort)
307+
expected = getattr(idx, method)(idx[:5], sort=sort).rename(res_names)
308+
tm.assert_index_equal(result, expected)
309+
310+
297311
def test_intersection_non_object(idx, sort):
298312
other = Index(range(3), name="foo")
299313

pandas/tests/indexes/test_setops.py

+17
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,20 @@ def test_intersection_difference_match_empty(self, index, sort):
446446
inter = index.intersection(index[:0])
447447
diff = index.difference(index, sort=sort)
448448
tm.assert_index_equal(inter, diff, exact=True)
449+
450+
451+
@pytest.mark.parametrize("method", ["intersection", "union"])
452+
def test_setop_with_categorical(index, sort, method):
453+
if isinstance(index, MultiIndex):
454+
# tested separately in tests.indexes.multi.test_setops
455+
return
456+
457+
other = index.astype("category")
458+
459+
result = getattr(index, method)(other, sort=sort)
460+
expected = getattr(index, method)(index, sort=sort)
461+
tm.assert_index_equal(result, expected)
462+
463+
result = getattr(index, method)(other[:5], sort=sort)
464+
expected = getattr(index, method)(index[:5], sort=sort)
465+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)