Skip to content

Commit a4573bb

Browse files
St0rmieLeventide
andcommitted
ENH: Implement difference method for Interval and IntervalArray (pandas-dev#21998)
Co-authored-by: Pedro Frigolet <[email protected]>
1 parent 3624ead commit a4573bb

File tree

5 files changed

+374
-0
lines changed

5 files changed

+374
-0
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Other enhancements
4444
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
4545
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
4646
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
47+
- Added :meth:`Interval.difference` and :meth:`IntervalArray.difference` to calculate the difference between interval-like objects (:issue:`21998`)
4748
- Added :meth:`Interval.intersection` and :meth:`IntervalArray.intersection` to calculate the intersection between interval-like objects (:issue:`21998`)
4849
- Added :meth:`Interval.union` and :meth:`IntervalArray.union` to calculate the union between interval-like objects (:issue:`21998`)
4950
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)

pandas/_libs/interval.pyx

+132
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,138 @@ cdef class Interval(IntervalMixin):
783783
closed = "neither"
784784
return np.array([Interval(uleft, uright, closed=closed)], dtype=object)
785785

786+
def difference(self, other):
    """
    Return the difference between an interval and another.

    The difference between two intervals is the set of points in the first
    interval that are not contained in the second interval.

    Parameters
    ----------
    other : Interval
        Interval whose points are removed from this interval.

    Returns
    -------
    np.ndarray
        Object-dtype array with the pieces that remain. It holds two
        intervals if ``other`` splits this interval, one interval if
        only one side is trimmed or nothing is removed, and it is empty
        if no points are left after the difference.

    Raises
    ------
    TypeError
        If ``other`` is not an :class:`Interval`.

    Examples
    --------
    >>> i0 = pd.Interval(0, 3, closed='right')
    >>> i1 = pd.Interval(2, 4, closed='right')
    >>> i0.difference(i1)
    array([Interval(0, 2, closed='right')], dtype=object)

    >>> i2 = pd.Interval(5, 8, closed='right')
    >>> i0.difference(i2)
    array([Interval(0, 3, closed='right')], dtype=object)

    >>> i3 = pd.Interval(3, 5, closed='left')
    >>> i0.difference(i3)
    array([Interval(0, 3, closed='neither')], dtype=object)

    >>> i4 = pd.Interval(-2, 7, closed='left')
    >>> i0.difference(i4)
    array([], dtype=object)

    >>> i4.difference(i0)
    array([Interval(-2, 0, closed='both'), Interval(3, 7, closed='neither')],
          dtype=object)
    """
    if not isinstance(other, Interval):
        raise TypeError("`other` must be an Interval, "
                        f"got {type(other).__name__}")

    # Removing an empty interval, or one that shares no points with this
    # interval, leaves this interval untouched.
    if other.is_empty or not self.overlaps(other):
        return np.array([self], dtype=object)

    closed_names = {
        (True, True): "both",
        (True, False): "left",
        (False, True): "right",
        (False, False): "neither",
    }

    # Because the intervals overlap, whatever survives lies either to the
    # left of ``other`` or to the right of it. A bound inherited from
    # ``other`` is closed exactly when ``other`` is open there, since only
    # then does the boundary point stay in the result.
    candidates = (
        (self.left, other.left, self.closed_left, not other.closed_left),
        (other.right, self.right, not other.closed_right, self.closed_right),
    )

    pieces = []
    for left, right, closed_left, closed_right in candidates:
        if left < right:
            closed = closed_names[(bool(closed_left), bool(closed_right))]
            pieces.append(Interval(left, right, closed=closed))
        elif left == right and closed_left and closed_right:
            # a single surviving point must be closed on both sides
            pieces.append(Interval(left, right, closed="both"))
        # otherwise the candidate holds no points and is dropped

    return np.array(pieces, dtype=object)
917+
786918

787919
@cython.wraparound(False)
788920
@cython.boundscheck(False)

pandas/core/arrays/interval.py

+65
Original file line numberDiff line numberDiff line change
@@ -1510,6 +1510,71 @@ def union(self, other):
15101510

15111511
return np.array([interval.union(other) for interval in self], dtype=object)
15121512

1513+
_interval_shared_docs["difference"] = textwrap.dedent(
    """
    Calculate the difference between each Interval in the %(klass)s and a
    given Interval.

    The difference between two intervals is the set of points in the first
    interval that are not contained in the second interval.

    Parameters
    ----------
    other : Interval
        Interval whose points are removed from each interval.

    Returns
    -------
    numpy.ndarray
        One-dimensional object array with one entry per interval. Each
        entry is itself an object array holding the pieces of that
        interval that remain after removing ``other``.

    See Also
    --------
    Interval.difference : Calculate difference between two Interval objects.

    Examples
    --------
    %(examples)s

    >>> result = intervals.difference(pd.Interval(2, 3, closed='both'))
    >>> len(result)
    3
    >>> result[0]
    array([Interval(0, 1, closed='right')], dtype=object)
    >>> result[1]
    array([Interval(1, 2, closed='neither'), Interval(3, 5, closed='right')],
          dtype=object)
    >>> result[2]
    array([Interval(3, 4, closed='right')], dtype=object)
    """
)

@Appender(
    _interval_shared_docs["difference"]
    % {
        "klass": "IntervalArray",
        "examples": textwrap.dedent(
            """\
        >>> data = [(0, 1), (1, 5), (2, 4)]
        >>> intervals = pd.arrays.IntervalArray.from_tuples(data)
        >>> intervals
        <IntervalArray>
        [(0, 1], (1, 5], (2, 4]]
        Length: 3, dtype: interval[int64, right]
        """
        ),
    }
)
def difference(self, other):
    # Array-like ``other`` is recognised but not supported yet.
    if isinstance(other, (IntervalArray, ABCIntervalIndex)):
        raise NotImplementedError
    if not isinstance(other, Interval):
        msg = f"`other` must be Interval-like, got {type(other).__name__}"
        raise TypeError(msg)

    # Fill a pre-allocated 1-D object array element by element. Passing the
    # list of per-interval arrays to ``np.array`` would instead yield a 2-D
    # array whenever every piece happens to have the same length, so the
    # result's dimensionality would depend on the data.
    result = np.empty(len(self), dtype=object)
    for i, interval in enumerate(self):
        result[i] = interval.difference(other)
    return result
1577+
15131578
# ---------------------------------------------------------------------
15141579

15151580
@property

pandas/tests/arrays/interval/test_overlaps.py

+44
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,47 @@ def test_union_invalid_type(self, other):
177177
msg = f"`other` must be Interval-like, got {type(other).__name__}"
178178
with pytest.raises(TypeError, match=msg):
179179
interval_container.union(other)
180+
181+
182+
class TestDifference:
    """Tests for ``IntervalArray.difference``."""

    def test_difference_interval_array(self):
        # Every case pairs the bounds of a closed-on-both-sides interval with
        # the pieces expected once ``[1, 8)`` has been removed from it.
        interval = Interval(1, 8, "left")
        cases = [
            ((1, 8), [Interval(8, 8, "both")]),  # identical
            ((2, 4), []),  # nested
            ((0, 9), [Interval(0, 1, "left"), Interval(8, 9, "both")]),  # spanning
            ((4, 10), [Interval(8, 10, "both")]),  # partial
            ((-5, 1), [Interval(-5, 1, "left")]),  # adjacent closed
            ((8, 10), [Interval(8, 10, "both")]),  # adjacent open
            ((10, 15), [Interval(10, 15, "both")]),  # disjoint
        ]

        tuples = [bounds for bounds, _ in cases]
        interval_container = IntervalArray.from_tuples(tuples, "both")
        expected = np.array(
            [np.array(pieces, dtype=object) for _, pieces in cases],
            dtype=object,
        )

        result = interval_container.difference(interval)
        tm.assert_numpy_array_equal(result, expected)

    @pytest.mark.parametrize(
        "other",
        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
        ids=lambda x: type(x).__name__,
    )
    def test_difference_invalid_type(self, other):
        # anything that is not an Interval must be rejected with a TypeError
        type_name = type(other).__name__
        msg = f"`other` must be Interval-like, got {type_name}"
        interval_container = IntervalArray.from_breaks(range(5))
        with pytest.raises(TypeError, match=msg):
            interval_container.difference(other)

0 commit comments

Comments
 (0)