Skip to content

Commit 0e70741

Browse files
St0rmieLeventide
andcommitted
ENH: Implement union method for Interval and IntervalArray (pandas-dev#21998)
Co-authored-by: Pedro Frigolet <[email protected]>
1 parent 9834d15 commit 0e70741

File tree

5 files changed

+318
-0
lines changed

5 files changed

+318
-0
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Other enhancements
4747
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
4848
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
4949
- Added :meth:`Interval.intersection` and :meth:`IntervalArray.intersection` to calculate the intersection between interval-like objects (:issue:`21998`)
50+
- Added :meth:`Interval.union` and :meth:`IntervalArray.union` to calculate the union between interval-like objects (:issue:`21998`)
5051
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
5152
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
5253

pandas/_libs/interval.pyx

+88
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,94 @@ cdef class Interval(IntervalMixin):
695695
closed = "neither"
696696
return Interval(ileft, iright, closed=closed)
697697

698+
def union(self, other):
699+
"""
700+
Return the union of two intervals.
701+
702+
The union of two intervals are all the values in both, including
703+
closed endpoints.
704+
705+
Parameters
706+
----------
707+
other : Interval
708+
Interval with which to create a union.
709+
710+
Returns
711+
-------
712+
np.array
713+
numpy array with one interval if there is overlap between
714+
the two intervals, with two intervals if there is no overlap.
715+
716+
See Also
717+
--------
718+
IntervalArray.union : The corresponding method for IntervalArray.
719+
720+
Examples
721+
--------
722+
>>> i0 = pd.Interval(0, 3, closed='right')
723+
>>> i1 = pd.Interval(2, 4, closed='right')
724+
>>> i0.union(i1)
725+
array([Interval(0, 4, closed='right')], dtype=object)
726+
727+
>>> i2 = pd.Interval(5, 8, closed='right')
728+
>>> i0.union(i2)
729+
array([Interval(0, 3, closed='right') Interval(5, 8, closed='right')],
730+
dtype=object)
731+
732+
>>> i3 = pd.Interval(3, 5, closed='right')
733+
>>> i0.union(i3)
734+
array([Interval(0, 5, closed='right')], dtype=object)
735+
"""
736+
if not isinstance(other, Interval):
737+
raise TypeError("`other` must be an Interval, "
738+
f"got {type(other).__name__}")
739+
740+
# if there is no overlap return the two intervals
741+
# except if the two intervals share an endpoint were one side is closed
742+
if not self.overlaps(other):
743+
if(not(
744+
(self.left == other.right and
745+
(self.closed_left or other.closed_right))
746+
or
747+
(self.right == other.left and
748+
(self.closed_right or other.closed_left)))):
749+
if self.left < other.left:
750+
return np.array([self, other], dtype=object)
751+
else:
752+
return np.array([other, self], dtype=object)
753+
754+
# Define left limit
755+
if self.left < other.left:
756+
uleft = self.left
757+
lclosed = self.closed_left
758+
elif self.left > other.left:
759+
uleft = other.left
760+
lclosed = other.closed_left
761+
else:
762+
uleft = self.left
763+
lclosed = self.closed_left or other.closed_left
764+
765+
# Define right limit
766+
if self.right > other.right:
767+
uright = self.right
768+
rclosed = self.closed_right
769+
elif self.right < other.right:
770+
uright = other.right
771+
rclosed = other.closed_right
772+
else:
773+
uright = self.right
774+
rclosed = self.closed_right or other.closed_right
775+
776+
if lclosed and rclosed:
777+
closed = "both"
778+
elif lclosed:
779+
closed = "left"
780+
elif rclosed:
781+
closed = "right"
782+
else:
783+
closed = "neither"
784+
return np.array([Interval(uleft, uright, closed=closed)], dtype=object)
785+
698786

699787
@cython.wraparound(False)
700788
@cython.boundscheck(False)

pandas/core/arrays/interval.py

+66
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,72 @@ def intersection(self, other):
14441444
[interval.intersection(other) for interval in self], dtype=object
14451445
)
14461446

1447+
_interval_shared_docs["union"] = textwrap.dedent(
1448+
"""
1449+
Calculates union between each interval in the %(klass)s and a given
1450+
Interval.
1451+
1452+
The union of two intervals are all the values in both, including
1453+
closed endpoints.
1454+
1455+
Parameters
1456+
----------
1457+
other : Interval
1458+
Interval to which to calculate the union.
1459+
1460+
Returns
1461+
-------
1462+
array
1463+
Array of arrays containing the unions between each interval and other.
1464+
1465+
See Also
1466+
--------
1467+
Interval.union : Calculate union between two Interval objects.
1468+
1469+
Examples
1470+
--------
1471+
%(examples)s
1472+
1473+
>>> intervals.union(pd.Interval(0, 2, 'right'))
1474+
array([[Interval(0, 2, closed='right')],
1475+
[Interval(0, 5, closed='right')],
1476+
[Interval(0, 4, closed='right')]], dtype=object)
1477+
1478+
>>> intervals.union(pd.Interval(5, 8, closed='left'))
1479+
array([array([Interval(0, 1, closed='right'), Interval(5, 8, closed='left')],
1480+
dtype=object) ,
1481+
array([Interval(1, 8, closed='neither')], dtype=object),
1482+
array([Interval(2, 4, closed='right'), Interval(5, 8, closed='left')],
1483+
dtype=object) ],
1484+
dtype=object)
1485+
"""
1486+
)
1487+
1488+
@Appender(
1489+
_interval_shared_docs["union"]
1490+
% {
1491+
"klass": "IntervalArray",
1492+
"examples": textwrap.dedent(
1493+
"""\
1494+
>>> data = [(0, 1), (1, 5), (2, 4)]
1495+
>>> intervals = pd.arrays.IntervalArray.from_tuples(data)
1496+
>>> intervals
1497+
<IntervalArray>
1498+
[(0, 1], (1, 5], (2, 4]]
1499+
Length: 3, dtype: interval[int64, right]
1500+
"""
1501+
),
1502+
}
1503+
)
1504+
def union(self, other):
1505+
if isinstance(other, (IntervalArray, ABCIntervalIndex)):
1506+
raise NotImplementedError
1507+
if not isinstance(other, Interval):
1508+
msg = f"`other` must be Interval-like, got {type(other).__name__}"
1509+
raise TypeError(msg)
1510+
1511+
return np.array([interval.union(other) for interval in self], dtype=object)
1512+
14471513
# ---------------------------------------------------------------------
14481514

14491515
@property

pandas/tests/arrays/interval/test_overlaps.py

+44
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,47 @@ def test_intersection_invalid_type(self, other):
133133
msg = f"`other` must be Interval-like, got {type(other).__name__}"
134134
with pytest.raises(TypeError, match=msg):
135135
interval_container.intersection(other)
136+
137+
138+
class TestUnion:
139+
def test_union_interval_array(self):
140+
interval = Interval(1, 8, "left")
141+
142+
tuples = [ # Intervals:
143+
(1, 8), # identical
144+
(2, 4), # nested
145+
(0, 9), # spanning
146+
(4, 10), # partial
147+
(-5, 1), # adjacent closed
148+
(8, 10), # adjacent open
149+
(10, 15), # disjoint
150+
]
151+
interval_container = IntervalArray.from_tuples(tuples, "both")
152+
153+
expected = np.array(
154+
[
155+
np.array([Interval(1, 8, "both")], dtype=object),
156+
np.array([Interval(1, 8, "left")], dtype=object),
157+
np.array([Interval(0, 9, "both")], dtype=object),
158+
np.array([Interval(1, 10, "both")], dtype=object),
159+
np.array([Interval(-5, 8, "left")], dtype=object),
160+
np.array([Interval(1, 10, "both")], dtype=object),
161+
np.array(
162+
[Interval(1, 8, "left"), Interval(10, 15, "both")], dtype=object
163+
),
164+
],
165+
dtype=object,
166+
)
167+
result = interval_container.union(interval)
168+
tm.assert_numpy_array_equal(result, expected)
169+
170+
@pytest.mark.parametrize(
171+
"other",
172+
[10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
173+
ids=lambda x: type(x).__name__,
174+
)
175+
def test_union_invalid_type(self, other):
176+
interval_container = IntervalArray.from_breaks(range(5))
177+
msg = f"`other` must be Interval-like, got {type(other).__name__}"
178+
with pytest.raises(TypeError, match=msg):
179+
interval_container.union(other)

pandas/tests/scalar/interval/test_overlaps.py

+119
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,122 @@ def test_intersection_invalid_type(self, other):
166166
msg = f"`other` must be an Interval, got {type(other).__name__}"
167167
with pytest.raises(TypeError, match=msg):
168168
interval.intersection(other)
169+
170+
171+
class TestUnion:
172+
def test_union_self(self):
173+
interval = Interval(1, 8, "left")
174+
175+
result = interval.union(interval)
176+
177+
expected = np.array([interval], dtype=object)
178+
tm.assert_numpy_array_equal(result, expected)
179+
180+
def test_union_include_limits(self):
181+
other = Interval(1, 8, "left")
182+
183+
intervals = np.array(
184+
[
185+
Interval(7, 9, "left"), # include left
186+
Interval(0, 2, "right"), # include right
187+
Interval(1, 8, "right"), # open limit
188+
]
189+
)
190+
191+
expected = np.array(
192+
[
193+
np.array([Interval(1, 9, "left")], dtype=object),
194+
np.array([Interval(0, 8, "neither")], dtype=object),
195+
np.array([Interval(1, 8, "both")], dtype=object),
196+
],
197+
dtype=object,
198+
)
199+
200+
result = np.array([interval.union(other) for interval in intervals])
201+
tm.assert_numpy_array_equal(result, expected)
202+
203+
def test_union_overlapping(self):
204+
other = Interval(1, 8, "left")
205+
206+
intervals = np.array(
207+
[
208+
Interval(2, 4, "both"), # nested
209+
Interval(0, 9, "both"), # spanning
210+
Interval(4, 10, "both"), # partial
211+
]
212+
)
213+
214+
expected = np.array(
215+
[
216+
np.array([Interval(1, 8, "left")], dtype=object),
217+
np.array([Interval(0, 9, "both")], dtype=object),
218+
np.array([Interval(1, 10, "both")], dtype=object),
219+
],
220+
dtype=object,
221+
)
222+
223+
result = np.array([interval.union(other) for interval in intervals])
224+
tm.assert_numpy_array_equal(result, expected)
225+
226+
def test_union_adjacent(self):
227+
other = Interval(1, 8, "left")
228+
229+
intervals = np.array(
230+
[
231+
Interval(-5, 1, "both"), # adjacent closed
232+
Interval(8, 10, "both"), # adjacent open
233+
Interval(10, 15, "both"), # disjoint
234+
]
235+
)
236+
237+
expected = np.array(
238+
[
239+
np.array([Interval(-5, 8, "left")], dtype=object),
240+
np.array([Interval(1, 10, "both")], dtype=object),
241+
np.array([other, Interval(10, 15, "both")], dtype=object),
242+
],
243+
dtype=object,
244+
)
245+
246+
result = np.array(
247+
[interval.union(other) for interval in intervals], dtype=object
248+
)
249+
tm.assert_numpy_array_equal(result, expected)
250+
251+
def test_union_timestamps(self):
252+
year_2020 = Interval(
253+
Timestamp("2020-01-01 00:00:00"),
254+
Timestamp("2021-01-01 00:00:00"),
255+
closed="left",
256+
)
257+
258+
year_2021 = Interval(
259+
Timestamp("2021-01-01 00:00:00"),
260+
Timestamp("2022-01-01 00:00:00"),
261+
closed="left",
262+
)
263+
264+
expected = np.array(
265+
[
266+
Interval(
267+
Timestamp("2020-01-01 00:00:00"),
268+
Timestamp("2022-01-01 00:00:00"),
269+
closed="left",
270+
)
271+
],
272+
dtype=object,
273+
)
274+
275+
result = year_2020.union(year_2021)
276+
tm.assert_numpy_array_equal(result, expected)
277+
278+
@pytest.mark.parametrize(
279+
"other",
280+
[10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
281+
ids=lambda x: type(x).__name__,
282+
)
283+
def test_union_invalid_type(self, other):
284+
interval = Interval(0, 1)
285+
msg = f"`other` must be an Interval, got {type(other).__name__}"
286+
with pytest.raises(TypeError, match=msg):
287+
interval.union(other)

0 commit comments

Comments
 (0)