Skip to content

Commit bf72b70

Browse files
authored
REF: de-duplicate symmetric_difference, _union (#41833)
1 parent 4a00fcc commit bf72b70

File tree

6 files changed

+22
-46
lines changed

6 files changed

+22
-46
lines changed

pandas/core/indexes/base.py

+6-26
Original file line numberDiff line numberDiff line change
@@ -3246,33 +3246,13 @@ def symmetric_difference(self, other, result_name=None, sort=None):
32463246
if result_name is None:
32473247
result_name = result_name_update
32483248

3249-
if not self._should_compare(other):
3250-
return self.union(other, sort=sort).rename(result_name)
3251-
elif not is_dtype_equal(self.dtype, other.dtype):
3252-
dtype = find_common_type([self.dtype, other.dtype])
3253-
this = self.astype(dtype, copy=False)
3254-
that = other.astype(dtype, copy=False)
3255-
return this.symmetric_difference(that, sort=sort).rename(result_name)
3256-
3257-
this = self._get_unique_index()
3258-
other = other._get_unique_index()
3259-
indexer = this.get_indexer_for(other)
3249+
left = self.difference(other, sort=False)
3250+
right = other.difference(self, sort=False)
3251+
result = left.union(right, sort=sort)
32603252

3261-
# {this} minus {other}
3262-
common_indexer = indexer.take((indexer != -1).nonzero()[0])
3263-
left_indexer = np.setdiff1d(
3264-
np.arange(this.size), common_indexer, assume_unique=True
3265-
)
3266-
left_diff = this._values.take(left_indexer)
3267-
3268-
# {other} minus {this}
3269-
right_indexer = (indexer == -1).nonzero()[0]
3270-
right_diff = other._values.take(right_indexer)
3271-
3272-
the_diff = concat_compat([left_diff, right_diff])
3273-
the_diff = _maybe_try_sort(the_diff, sort)
3274-
3275-
return Index(the_diff, name=result_name)
3253+
if result_name is not None:
3254+
result = result.rename(result_name)
3255+
return result
32763256

32773257
@final
32783258
def _assert_can_do_setop(self, other) -> bool:

pandas/core/indexes/category.py

+4
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,10 @@ def _is_dtype_compat(self, other) -> Categorical:
269269
raise TypeError(
270270
"categories must match existing categories when appending"
271271
)
272+
273+
elif other._is_multi:
274+
# preempt raising NotImplementedError in isna call
275+
raise TypeError("MultiIndex is not dtype-compatible with CategoricalIndex")
272276
else:
273277
values = other
274278

pandas/core/indexes/datetimelike.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
inherit_names,
6161
make_wrapped_arith_op,
6262
)
63-
from pandas.core.indexes.numeric import Int64Index
6463
from pandas.core.tools.timedeltas import to_timedelta
6564

6665
if TYPE_CHECKING:
@@ -779,11 +778,7 @@ def _union(self, other, sort):
779778
# that result.freq == self.freq
780779
return result
781780
else:
782-
i8self = Int64Index._simple_new(self.asi8)
783-
i8other = Int64Index._simple_new(other.asi8)
784-
i8result = i8self._union(i8other, sort=sort)
785-
result = type(self)(i8result, dtype=self.dtype, freq="infer")
786-
return result
781+
return super()._union(other, sort=sort)._with_freq("infer")
787782

788783
# --------------------------------------------------------------------
789784
# Join Methods

pandas/core/indexes/multi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3680,9 +3680,9 @@ def symmetric_difference(self, other, result_name=None, sort=None):
36803680
return type(self)(
36813681
levels=[[] for _ in range(self.nlevels)],
36823682
codes=[[] for _ in range(self.nlevels)],
3683-
names=tups.name,
3683+
names=tups.names,
36843684
)
3685-
return type(self).from_tuples(tups, names=tups.name)
3685+
return tups
36863686

36873687
# --------------------------------------------------------------------
36883688

pandas/core/indexes/range.py

-12
Original file line numberDiff line numberDiff line change
@@ -730,18 +730,6 @@ def _difference(self, other, sort=None):
730730
new_index = new_index[::-1]
731731
return new_index
732732

733-
def symmetric_difference(self, other, result_name: Hashable = None, sort=None):
734-
if not isinstance(other, RangeIndex) or sort is not None:
735-
return super().symmetric_difference(other, result_name, sort)
736-
737-
left = self.difference(other)
738-
right = other.difference(self)
739-
result = left.union(right)
740-
741-
if result_name is not None:
742-
result = result.rename(result_name)
743-
return result
744-
745733
# --------------------------------------------------------------------
746734

747735
def _concat(self, indexes: list[Index], name: Hashable) -> Index:

pandas/tests/indexes/categorical/test_equals.py

+9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Categorical,
66
CategoricalIndex,
77
Index,
8+
MultiIndex,
89
)
910

1011

@@ -79,3 +80,11 @@ def test_equals_non_category(self):
7980
other = Index(["A", "B", "D", np.nan])
8081

8182
assert not ci.equals(other)
83+
84+
def test_equals_multiindex(self):
85+
# dont raise NotImplementedError when calling is_dtype_compat
86+
87+
mi = MultiIndex.from_arrays([["A", "B", "C", "D"], range(4)])
88+
ci = mi.to_flat_index().astype("category")
89+
90+
assert not ci.equals(mi)

0 commit comments

Comments
 (0)