Skip to content

Commit 9834d15

Browse files
St0rmieLeventide
andcommitted
ENH: Implement intersection method for Interval and IntervalArray (pandas-dev#21998)
Co-authored-by: Pedro Frigolet <[email protected]>
1 parent bef88ef commit 9834d15

File tree

5 files changed

+282
-0
lines changed

5 files changed

+282
-0
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Other enhancements
4646
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
4747
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
4848
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
49+
- Added :meth:`Interval.intersection` and :meth:`IntervalArray.intersection` to calculate the intersection between interval-like objects (:issue:`21998`)
4950
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
5051
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
5152

pandas/_libs/interval.pyx

+75
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,81 @@ cdef class Interval(IntervalMixin):
620620
# (simplifying the negation allows this to be done in less operations)
621621
return op1(self.left, other.right) and op2(other.left, self.right)
622622

623+
def intersection(self, other):
    """
    Return the intersection of two intervals.

    The intersection of two intervals is the common points shared between both,
    including closed endpoints. Open endpoints are not included.

    Parameters
    ----------
    other : Interval
        Interval to which to calculate the intersection.

    Returns
    -------
    Interval or None
        Interval containing the shared points and its closedness or None in
        case there's no intersection.

    Raises
    ------
    TypeError
        If ``other`` is not an Interval.

    See Also
    --------
    IntervalArray.intersection : The corresponding method for IntervalArray.

    Examples
    --------
    >>> i0 = pd.Interval(0, 3, closed='right')
    >>> i1 = pd.Interval(2, 4, closed='right')
    >>> i0.intersection(i1)
    Interval(2, 3, closed='right')

    Each bound keeps the closedness of the interval that supplies it:

    >>> i3 = pd.Interval(2, 4, closed='right')
    >>> i3.intersection(pd.Interval(0, 9, closed='both'))
    Interval(2, 4, closed='right')

    Intervals that have no intersection:

    >>> i2 = pd.Interval(5, 8, closed='right')
    >>> i0.intersection(i2) is None
    True
    """
    if not isinstance(other, Interval):
        raise TypeError("`other` must be an Interval, "
                        f"got {type(other).__name__}")

    # Left limit: the larger left endpoint wins and brings its own closedness.
    if self.left < other.left:
        ileft = other.left
        lclosed = other.closed_left
    elif self.left > other.left:
        ileft = self.left
        # The bound comes from ``self``, so its closedness must too.
        lclosed = self.closed_left
    else:
        # Shared endpoint is included only if both intervals include it.
        ileft = self.left
        lclosed = self.closed_left and other.closed_left

    # Right limit: the smaller right endpoint wins and brings its own closedness.
    if self.right < other.right:
        iright = self.right
        rclosed = self.closed_right
    elif self.right > other.right:
        iright = other.right
        rclosed = other.closed_right
    else:
        iright = self.right
        rclosed = self.closed_right and other.closed_right

    # No intersection if there is no overlap. A single shared point only
    # counts when it is included on both sides.
    if iright < ileft or (iright == ileft and not (lclosed and rclosed)):
        return None

    if lclosed and rclosed:
        closed = "both"
    elif lclosed:
        closed = "left"
    elif rclosed:
        closed = "right"
    else:
        closed = "neither"
    return Interval(ileft, iright, closed=closed)
697+
623698

624699
@cython.wraparound(False)
625700
@cython.boundscheck(False)

pandas/core/arrays/interval.py

+64
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,70 @@ def overlaps(self, other):
13801380
# (simplifying the negation allows this to be done in less operations)
13811381
return op1(self.left, other.right) & op2(other.left, self.right)
13821382

1383+
_interval_shared_docs["intersection"] = textwrap.dedent(
    """
    Calculate the intersection of each interval in the %(klass)s with an Interval.

    The intersection of two intervals is the common points shared between both,
    including closed endpoints. Open endpoints are not included.

    Parameters
    ----------
    other : Interval
        Interval to which to calculate the intersection.

    Returns
    -------
    numpy.ndarray
        Object-dtype array containing the intersection between each interval
        and ``other``. Elements are ``None`` where there is no intersection.

    Raises
    ------
    NotImplementedError
        If ``other`` is an IntervalArray or IntervalIndex.
    TypeError
        If ``other`` is not an Interval.

    See Also
    --------
    Interval.intersection : Calculate intersection between two Interval objects.

    Examples
    --------
    %(examples)s
    >>> intervals.intersection(pd.Interval(0, 2, 'right'))
    array([Interval(0, 1, closed='right'), Interval(1, 2, closed='right'),
           None], dtype=object)

    Intersection with a single value:

    >>> intervals.intersection(pd.Interval(5, 8, closed='left'))
    array([None, Interval(5, 5, closed='both'), None], dtype=object)
    """
)


@Appender(
    _interval_shared_docs["intersection"]
    % {
        "klass": "IntervalArray",
        "examples": textwrap.dedent(
            """\
            >>> data = [(0, 1), (1, 5), (2, 4)]
            >>> intervals = pd.arrays.IntervalArray.from_tuples(data)
            >>> intervals
            <IntervalArray>
            [(0, 1], (1, 5], (2, 4]]
            Length: 3, dtype: interval[int64, right]
            """
        ),
    }
)
def intersection(self, other):
    # Array-like operands are not supported yet; only a scalar Interval is.
    if isinstance(other, (IntervalArray, ABCIntervalIndex)):
        raise NotImplementedError
    if not isinstance(other, Interval):
        msg = f"`other` must be Interval-like, got {type(other).__name__}"
        raise TypeError(msg)

    # Elements that are not Interval objects are passed through unchanged
    # instead of failing on a missing ``intersection`` attribute.
    # NOTE(review): presumably missing values iterate as a non-Interval
    # scalar such as NaN -- confirm against IntervalArray.__iter__.
    return np.array(
        [
            interval.intersection(other)
            if isinstance(interval, Interval)
            else interval
            for interval in self
        ],
        dtype=object,
    )
1446+
13831447
# ---------------------------------------------------------------------
13841448

13851449
@property

pandas/tests/arrays/interval/test_overlaps.py

+41
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,44 @@ def test_overlaps_invalid_type(self, constructor, other):
9292
msg = f"`other` must be Interval-like, got {type(other).__name__}"
9393
with pytest.raises(TypeError, match=msg):
9494
interval_container.overlaps(other)
95+
96+
97+
class TestIntersection:
    """Tests for ``IntervalArray.intersection`` with a scalar ``Interval``."""

    def test_intersection_interval_array(self):
        # Every element is closed on both sides; ``other`` is closed on the
        # left only, so its open right endpoint must never be included.
        other = Interval(1, 8, "left")

        bounds = [
            (1, 8),  # identical
            (2, 4),  # nested
            (0, 9),  # spanning
            (4, 10),  # partial
            (-5, 1),  # adjacent closed
            (8, 10),  # adjacent open
            (10, 15),  # disjoint
        ]
        container = IntervalArray.from_tuples(bounds, "both")
        result = container.intersection(other)

        # One expected entry per element of ``bounds``, in the same order.
        expected = np.array(
            [
                Interval(1, 8, "left"),
                Interval(2, 4, "both"),
                Interval(1, 8, "left"),
                Interval(4, 8, "left"),
                Interval(1, 1, "both"),
                None,
                None,
            ]
        )
        tm.assert_numpy_array_equal(result, expected)

    @pytest.mark.parametrize(
        "other",
        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
        ids=lambda value: type(value).__name__,
    )
    def test_intersection_invalid_type(self, other):
        # Anything that is not an Interval must be rejected with a TypeError.
        container = IntervalArray.from_breaks(range(5))
        expected_msg = f"`other` must be Interval-like, got {type(other).__name__}"
        with pytest.raises(TypeError, match=expected_msg):
            container.intersection(other)

pandas/tests/scalar/interval/test_overlaps.py

+101
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import numpy as np
12
import pytest
23

34
from pandas import (
45
Interval,
56
Timedelta,
67
Timestamp,
78
)
9+
import pandas._testing as tm
810

911

1012
@pytest.fixture(
@@ -65,3 +67,102 @@ def test_overlaps_invalid_type(self, other):
6567
msg = f"`other` must be an Interval, got {type(other).__name__}"
6668
with pytest.raises(TypeError, match=msg):
6769
interval.overlaps(other)
70+
71+
72+
class TestIntersection:
    """Tests for ``Interval.intersection`` between two scalar intervals."""

    @staticmethod
    def _intersect_each(candidates, other):
        # Intersect every candidate with ``other`` and collect the results.
        return np.array([candidate.intersection(other) for candidate in candidates])

    def test_intersection_self(self):
        # An interval intersected with itself is unchanged.
        interval = Interval(1, 8, "left")
        assert interval.intersection(interval) == interval

    def test_intersection_include_limits(self):
        other = Interval(1, 8, "left")

        candidates = np.array(
            [
                Interval(7, 9, "left"),  # include left
                Interval(0, 2, "right"),  # include right
                Interval(1, 8, "right"),  # open limit
            ]
        )
        expected = np.array(
            [
                Interval(7, 8, "left"),
                Interval(1, 2, "both"),
                Interval(1, 8, "neither"),
            ]
        )

        tm.assert_numpy_array_equal(self._intersect_each(candidates, other), expected)

    def test_intersection_overlapping(self):
        other = Interval(1, 8, "left")

        candidates = np.array(
            [
                Interval(2, 4, "both"),  # nested
                Interval(0, 9, "both"),  # spanning
                Interval(4, 10, "both"),  # partial
            ]
        )
        expected = np.array(
            [
                Interval(2, 4, "both"),
                Interval(1, 8, "left"),
                Interval(4, 8, "left"),
            ]
        )

        tm.assert_numpy_array_equal(self._intersect_each(candidates, other), expected)

    def test_intersection_adjacent(self):
        other = Interval(1, 8, "left")

        candidates = np.array(
            [
                Interval(-5, 1, "both"),  # adjacent closed
                Interval(8, 10, "both"),  # adjacent open
                Interval(10, 15, "both"),  # disjoint
            ]
        )
        expected = np.array(
            [
                Interval(1, 1, "both"),
                None,
                None,
            ]
        )

        tm.assert_numpy_array_equal(self._intersect_each(candidates, other), expected)

    def test_intersection_timestamps(self):
        # March 2020 lies entirely inside 2020, so it is the intersection.
        march_2020 = Interval(
            Timestamp("2020-03-01 00:00:00"),
            Timestamp("2020-04-01 00:00:00"),
            closed="left",
        )
        year_2020 = Interval(
            Timestamp("2020-01-01 00:00:00"),
            Timestamp("2021-01-01 00:00:00"),
            closed="left",
        )

        assert year_2020.intersection(march_2020) == march_2020

    @pytest.mark.parametrize(
        "other",
        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
        ids=lambda value: type(value).__name__,
    )
    def test_intersection_invalid_type(self, other):
        # Anything that is not an Interval must be rejected with a TypeError.
        interval = Interval(0, 1)
        expected_msg = f"`other` must be an Interval, got {type(other).__name__}"
        with pytest.raises(TypeError, match=expected_msg):
            interval.intersection(other)

0 commit comments

Comments
 (0)