Skip to content

Commit 6f55671

Browse files
committed
ENH: return RangeIndex from difference, symmetric_difference where possible closes pandas-dev#12034
1 parent d9722ef commit 6f55671

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

pandas/core/indexes/range.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,9 @@ def equals(self, other: object) -> bool:
466466
return self._range == other._range
467467
return super().equals(other)
468468

469+
# --------------------------------------------------------------------
470+
# Set Operations
471+
469472
def intersection(self, other, sort=False):
470473
"""
471474
Form the intersection of two Index objects.
@@ -632,6 +635,57 @@ def _union(self, other, sort):
632635
return type(self)(start_r, end_r + step_o, step_o)
633636
return self._int64index._union(other, sort=sort)
634637

638+
def difference(self, other, sort=None):
639+
# optimized set operation if we have another RangeIndex
640+
self._validate_sort_keyword(sort)
641+
642+
if not isinstance(other, RangeIndex):
643+
return super().difference(other, sort=sort)
644+
645+
res_name = ops.get_op_result_name(self, other)
646+
647+
first = self._range[::-1] if self.step < 0 else self._range
648+
overlap = self.intersection(other)
649+
if overlap.step < 0:
650+
overlap = overlap[::-1]
651+
652+
if len(overlap) == 0:
653+
return self._shallow_copy(name=res_name)
654+
if len(overlap) == len(self):
655+
return self[:0].rename(res_name)
656+
if not isinstance(overlap, RangeIndex):
657+
# We wont end up with RangeIndex, so fall back
658+
return super().difference(other, sort=sort)
659+
660+
if overlap[0] == first.start:
661+
# The difference is everything after the intersection
662+
new_rng = range(overlap[-1] + first.step, first.stop, first.step)
663+
elif overlap[-1] == first.stop:
664+
# The difference is everything before the intersection
665+
new_rng = range(first.start, overlap[0] - first.step, first.step)
666+
else:
667+
# The difference is not range-like
668+
return super().difference(other, sort=sort)
669+
670+
new_index = type(self)._simple_new(new_rng, name=res_name)
671+
if first is not self._range:
672+
new_index = new_index[::-1]
673+
return new_index
674+
675+
def symmetric_difference(self, other, result_name=None, sort=None):
676+
if not isinstance(other, RangeIndex) or sort is not None:
677+
return super().symmetric_difference(other, result_name, sort)
678+
679+
left = self.difference(other)
680+
right = other.difference(self)
681+
result = left.union(right)
682+
683+
if result_name is not None:
684+
result = result.rename(result_name)
685+
return result
686+
687+
# --------------------------------------------------------------------
688+
635689
@doc(Int64Index.join)
636690
def join(self, other, how="left", level=None, return_indexers=False, sort=False):
637691
if how == "outer" and self is not other:
@@ -744,12 +798,17 @@ def __floordiv__(self, other):
744798
return self._simple_new(new_range, name=self.name)
745799
return self._int64index // other
746800

801+
# --------------------------------------------------------------------
802+
# Reductions
803+
747804
def all(self) -> bool:
748805
return 0 not in self._range
749806

750807
def any(self) -> bool:
751808
return any(self._range)
752809

810+
# --------------------------------------------------------------------
811+
753812
@classmethod
754813
def _add_numeric_methods_binary(cls):
755814
""" add in numeric methods, specialized to RangeIndex """

pandas/tests/indexes/ranges/test_setops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,51 @@ def test_union_sorted(self, unions):
239239
res3 = idx1._int64index.union(idx2, sort=None)
240240
tm.assert_index_equal(res2, expected_sorted, exact=True)
241241
tm.assert_index_equal(res3, expected_sorted)
242+
243+
def test_difference(self):
244+
# GH#12034 Cases where we operate against another RangeIndex and may
245+
# get back another RangeIndex
246+
obj = RangeIndex.from_range(range(1, 10))
247+
248+
result = obj.difference(obj)
249+
expected = RangeIndex.from_range(range(0))
250+
tm.assert_index_equal(result, expected)
251+
252+
result = obj.difference(expected)
253+
tm.assert_index_equal(result, obj)
254+
255+
result = obj.difference(obj[:3])
256+
tm.assert_index_equal(result, obj[3:])
257+
258+
result = obj.difference(obj[-3:])
259+
tm.assert_index_equal(result, obj[:-3])
260+
261+
result = obj.difference(obj[2:6])
262+
expected = Int64Index([1, 2, 7, 8, 9])
263+
tm.assert_index_equal(result, expected)
264+
265+
def test_symmetric_difference(self):
266+
# GH#12034 Cases where we operate against another RangeIndex and may
267+
# get back another RangeIndex
268+
left = RangeIndex.from_range(range(1, 10))
269+
270+
result = left.symmetric_difference(left)
271+
expected = RangeIndex.from_range(range(0))
272+
tm.assert_index_equal(result, expected)
273+
274+
result = left.symmetric_difference(expected)
275+
tm.assert_index_equal(result, left)
276+
277+
result = left[:-2].symmetric_difference(left[2:])
278+
expected = Int64Index([1, 2, 8, 9])
279+
tm.assert_index_equal(result, expected)
280+
281+
right = RangeIndex.from_range(range(10, 15))
282+
283+
result = left.symmetric_difference(right)
284+
expected = RangeIndex.from_range(range(1, 15))
285+
tm.assert_index_equal(result, expected)
286+
287+
result = left.symmetric_difference(right[1:])
288+
expected = Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
289+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)