Skip to content

Commit b20b009

Browse files
jbrockmendel authored and luckyvs1 committed
BUG: IntervalIndex, PeriodIndex, DatetimeIndex symmetric_difference with Categorical (pandas-dev#38741)
1 parent 56a6d15 commit b20b009

File tree

6 files changed

+24
-22
lines changed

6 files changed

+24
-22
lines changed

doc/source/whatsnew/v1.3.0.rst

+2 -1
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ Datetimelike
190190
- Bug in :class:`DataFrame` and :class:`Series` constructors sometimes dropping nanoseconds from :class:`Timestamp` (resp. :class:`Timedelta`) ``data``, with ``dtype=datetime64[ns]`` (resp. ``timedelta64[ns]``) (:issue:`38032`)
191191
- Bug in :meth:`DataFrame.first` and :meth:`Series.first` returning two months for offset one month when first day is last calendar day (:issue:`29623`)
192192
- Bug in constructing a :class:`DataFrame` or :class:`Series` with mismatched ``datetime64`` data and ``timedelta64`` dtype, or vice-versa, failing to raise ``TypeError`` (:issue:`38575`)
193+
- Bug in :meth:`DatetimeIndex.intersection`, :meth:`DatetimeIndex.symmetric_difference`, :meth:`PeriodIndex.intersection`, :meth:`PeriodIndex.symmetric_difference` always returning object-dtype when operating with :class:`CategoricalIndex` (:issue:`38741`)
193194

194195
Timedelta
195196
^^^^^^^^^
@@ -221,7 +222,7 @@ Strings
221222

222223
Interval
223224
^^^^^^^^
224-
- Bug in :meth:`IntervalIndex.intersection` always returning object-dtype when intersecting with :class:`CategoricalIndex` (:issue:`38653`)
225+
- Bug in :meth:`IntervalIndex.intersection` and :meth:`IntervalIndex.symmetric_difference` always returning object-dtype when operating with :class:`CategoricalIndex` (:issue:`38653`, :issue:`38741`)
225226
-
226227
-
227228

pandas/core/indexes/base.py

+12 -2
Original file line number | Diff line number | Diff line change
@@ -2606,6 +2606,7 @@ def _validate_sort_keyword(self, sort):
26062606
f"None or False; {sort} was passed."
26072607
)
26082608

2609+
@final
26092610
def union(self, other, sort=None):
26102611
"""
26112612
Form the union of two Index objects.
@@ -2818,6 +2819,7 @@ def _wrap_setop_result(self, other, result):
28182819
return self._shallow_copy(result, name=name)
28192820

28202821
# TODO: standardize return type of non-union setops type(self vs other)
2822+
@final
28212823
def intersection(self, other, sort=False):
28222824
"""
28232825
Form the intersection of two Index objects.
@@ -3035,9 +3037,17 @@ def symmetric_difference(self, other, result_name=None, sort=None):
30353037
if result_name is None:
30363038
result_name = result_name_update
30373039

3040+
if not self._should_compare(other):
3041+
return self.union(other).rename(result_name)
3042+
elif not is_dtype_equal(self.dtype, other.dtype):
3043+
dtype = find_common_type([self.dtype, other.dtype])
3044+
this = self.astype(dtype, copy=False)
3045+
that = other.astype(dtype, copy=False)
3046+
return this.symmetric_difference(that, sort=sort).rename(result_name)
3047+
30383048
this = self._get_unique_index()
30393049
other = other._get_unique_index()
3040-
indexer = this.get_indexer(other)
3050+
indexer = this.get_indexer_for(other)
30413051

30423052
# {this} minus {other}
30433053
common_indexer = indexer.take((indexer != -1).nonzero()[0])
@@ -3057,7 +3067,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
30573067
except TypeError:
30583068
pass
30593069

3060-
return Index(the_diff, dtype=self.dtype, name=result_name)
3070+
return Index(the_diff, name=result_name)
30613071

30623072
def _assert_can_do_setop(self, other):
30633073
if not is_list_like(other):

pandas/core/indexes/interval.py

-1
Original file line number | Diff line number | Diff line change
@@ -1016,7 +1016,6 @@ def func(self, other, sort=sort):
10161016

10171017
_union = _setop("union")
10181018
difference = _setop("difference")
1019-
symmetric_difference = _setop("symmetric_difference")
10201019

10211020
# --------------------------------------------------------------------
10221021

pandas/core/indexes/period.py

-12
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,8 @@
1414
from pandas.core.dtypes.common import (
1515
is_bool_dtype,
1616
is_datetime64_any_dtype,
17-
is_dtype_equal,
1817
is_float,
1918
is_integer,
20-
is_object_dtype,
2119
is_scalar,
2220
pandas_dtype,
2321
)
@@ -635,16 +633,6 @@ def _setop(self, other, sort, opname: str):
635633
def _intersection(self, other, sort=False):
636634
return self._setop(other, sort, opname="intersection")
637635

638-
def _difference(self, other, sort):
639-
640-
if is_object_dtype(other):
641-
return self.astype(object).difference(other).astype(self.dtype)
642-
643-
elif not is_dtype_equal(self.dtype, other.dtype):
644-
return self
645-
646-
return self._setop(other, sort, opname="difference")
647-
648636
def _union(self, other, sort):
649637
return self._setop(other, sort, opname="_union")
650638

pandas/tests/indexes/interval/test_setops.py

+1
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ def test_symmetric_difference(self, closed, sort):
158158
index.left.astype("float64"), index.right, closed=closed
159159
)
160160
result = index.symmetric_difference(other, sort=sort)
161+
expected = empty_index(dtype="float64", closed=closed)
161162
tm.assert_index_equal(result, expected)
162163

163164
@pytest.mark.parametrize(

pandas/tests/indexes/test_setops.py

+9 -6
Original file line number | Diff line number | Diff line change
@@ -248,13 +248,14 @@ def test_symmetric_difference(self, index):
248248
# GH#10149
249249
cases = [klass(second.values) for klass in [np.array, Series, list]]
250250
for case in cases:
251+
result = first.symmetric_difference(case)
252+
251253
if is_datetime64tz_dtype(first):
252-
with pytest.raises(ValueError, match="Tz-aware"):
253-
# `second.values` casts to tznaive
254-
# TODO: should the symmetric_difference then be the union?
255-
first.symmetric_difference(case)
254+
# second.values casts to tznaive
255+
expected = first.union(case)
256+
tm.assert_index_equal(result, expected)
256257
continue
257-
result = first.symmetric_difference(case)
258+
258259
assert tm.equalContents(result, answer)
259260

260261
if isinstance(index, MultiIndex):
@@ -448,7 +449,9 @@ def test_intersection_difference_match_empty(self, index, sort):
448449
tm.assert_index_equal(inter, diff, exact=True)
449450

450451

451-
@pytest.mark.parametrize("method", ["intersection", "union"])
452+
@pytest.mark.parametrize(
453+
"method", ["intersection", "union", "difference", "symmetric_difference"]
454+
)
452455
def test_setop_with_categorical(index, sort, method):
453456
if isinstance(index, MultiIndex):
454457
# tested separately in tests.indexes.multi.test_setops

0 commit comments

Comments
 (0)