Skip to content

Commit 4f674a1

Browse files
authored
ENH: return RangeIndex from difference, symmetric_difference (#36564)
1 parent 09e7829 commit 4f674a1

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ Other enhancements
169169
- :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`)
170170
- :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`)
171171
- Added methods :meth:`IntegerArray.prod`, :meth:`IntegerArray.min`, and :meth:`IntegerArray.max` (:issue:`33790`)
172+
- Where possible, :meth:`RangeIndex.difference` and :meth:`RangeIndex.symmetric_difference` will return :class:`RangeIndex` instead of :class:`Int64Index` (:issue:`36564`)
172173

173174
.. _whatsnew_120.api_breaking.python:
174175

pandas/core/indexes/range.py

+59
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,9 @@ def equals(self, other: object) -> bool:
468468
return self._range == other._range
469469
return super().equals(other)
470470

471+
# --------------------------------------------------------------------
472+
# Set Operations
473+
471474
def intersection(self, other, sort=False):
472475
"""
473476
Form the intersection of two Index objects.
@@ -634,6 +637,57 @@ def _union(self, other, sort):
634637
return type(self)(start_r, end_r + step_o, step_o)
635638
return self._int64index._union(other, sort=sort)
636639

640+
def difference(self, other, sort=None):
641+
# optimized set operation if we have another RangeIndex
642+
self._validate_sort_keyword(sort)
643+
644+
if not isinstance(other, RangeIndex):
645+
return super().difference(other, sort=sort)
646+
647+
res_name = ops.get_op_result_name(self, other)
648+
649+
first = self._range[::-1] if self.step < 0 else self._range
650+
overlap = self.intersection(other)
651+
if overlap.step < 0:
652+
overlap = overlap[::-1]
653+
654+
if len(overlap) == 0:
655+
return self._shallow_copy(name=res_name)
656+
if len(overlap) == len(self):
657+
return self[:0].rename(res_name)
658+
if not isinstance(overlap, RangeIndex):
659+
# We wont end up with RangeIndex, so fall back
660+
return super().difference(other, sort=sort)
661+
662+
if overlap[0] == first.start:
663+
# The difference is everything after the intersection
664+
new_rng = range(overlap[-1] + first.step, first.stop, first.step)
665+
elif overlap[-1] == first.stop:
666+
# The difference is everything before the intersection
667+
new_rng = range(first.start, overlap[0] - first.step, first.step)
668+
else:
669+
# The difference is not range-like
670+
return super().difference(other, sort=sort)
671+
672+
new_index = type(self)._simple_new(new_rng, name=res_name)
673+
if first is not self._range:
674+
new_index = new_index[::-1]
675+
return new_index
676+
677+
def symmetric_difference(self, other, result_name=None, sort=None):
    """
    Return the values that appear in exactly one of ``self`` and ``other``.

    The range-aware path is used only when ``other`` is a RangeIndex and
    ``sort`` is None; it unions the two one-sided differences. Any other
    input is handed to the parent class implementation.

    Parameters
    ----------
    other : Index or array-like
    result_name : optional
        If not None, the result is renamed to this.
    sort : optional, default None

    Returns
    -------
    Index
    """
    use_range_path = isinstance(other, RangeIndex) and sort is None
    if not use_range_path:
        return super().symmetric_difference(other, result_name, sort)

    only_in_self = self.difference(other)
    only_in_other = other.difference(self)
    combined = only_in_self.union(only_in_other)

    if result_name is None:
        return combined
    return combined.rename(result_name)
688+
689+
# --------------------------------------------------------------------
690+
637691
@doc(Int64Index.join)
638692
def join(self, other, how="left", level=None, return_indexers=False, sort=False):
639693
if how == "outer" and self is not other:
@@ -746,12 +800,17 @@ def __floordiv__(self, other):
746800
return self._simple_new(new_range, name=self.name)
747801
return self._int64index // other
748802

803+
# --------------------------------------------------------------------
804+
# Reductions
805+
749806
def all(self) -> bool:
    """
    Return whether every value in the underlying range is nonzero.

    An empty range contains no zero, so the result is True for it.
    """
    contains_zero = 0 in self._range
    return not contains_zero
751808

752809
def any(self) -> bool:
    """
    Return whether at least one value in the underlying range is nonzero.
    """
    values = self._range
    return any(values)
754811

812+
# --------------------------------------------------------------------
813+
755814
@classmethod
756815
def _add_numeric_methods_binary(cls):
757816
""" add in numeric methods, specialized to RangeIndex """

pandas/tests/indexes/ranges/test_setops.py

+48
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,51 @@ def test_union_sorted(self, unions):
239239
res3 = idx1._int64index.union(idx2, sort=None)
240240
tm.assert_index_equal(res2, expected_sorted, exact=True)
241241
tm.assert_index_equal(res3, expected_sorted)
242+
243+
def test_difference(self):
    """Check RangeIndex.difference against other RangeIndex operands."""
    # GH#12034 Cases where we operate against another RangeIndex and may
    # get back another RangeIndex
    idx = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # Removing the index from itself leaves an empty index with the name
    tm.assert_index_equal(idx.difference(idx), empty)

    # Nothing is removed; the operand names differ, so the name is dropped
    tm.assert_index_equal(idx.difference(empty.rename("bar")), idx.rename(None))

    # Overlap at the head of the index
    tm.assert_index_equal(idx.difference(idx[:3]), idx[3:])

    # Overlap at the tail of the index
    tm.assert_index_equal(idx.difference(idx[-3:]), idx[:-3])

    # Overlap in the middle: the remainder is not a single range
    without_middle = Int64Index([1, 2, 7, 8, 9], name="foo")
    tm.assert_index_equal(idx.difference(idx[2:6]), without_middle)
264+
265+
def test_symmetric_difference(self):
    """Check RangeIndex.symmetric_difference against RangeIndex operands."""
    # GH#12034 Cases where we operate against another RangeIndex and may
    # get back another RangeIndex
    left = RangeIndex.from_range(range(1, 10), name="foo")
    empty = RangeIndex.from_range(range(0), name="foo")

    # An index against itself yields an empty index with the name kept
    tm.assert_index_equal(left.symmetric_difference(left), empty)

    # Against an empty index with a different name: all values, no name
    tm.assert_index_equal(
        left.symmetric_difference(empty.rename("bar")), left.rename(None)
    )

    # Partially overlapping slices leave two disjoint pieces
    both_ends = Int64Index([1, 2, 8, 9], name="foo")
    tm.assert_index_equal(left[:-2].symmetric_difference(left[2:]), both_ends)

    right = RangeIndex.from_range(range(10, 15))

    # Adjacent, non-overlapping ranges combine into one range
    joined = RangeIndex.from_range(range(1, 15))
    tm.assert_index_equal(left.symmetric_difference(right), joined)

    # A gap between the two ranges (10 is missing)
    with_gap = Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
    tm.assert_index_equal(left.symmetric_difference(right[1:]), with_gap)

0 commit comments

Comments
 (0)