diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst
index 130ccded72859..73e3aa295c89c 100644
--- a/doc/source/whatsnew/v3.0.0.rst
+++ b/doc/source/whatsnew/v3.0.0.rst
@@ -46,6 +46,9 @@ Other enhancements
 - :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
 - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
 - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
+- Added :meth:`Interval.difference` and :meth:`IntervalArray.difference` to calculate the difference between interval-like objects (:issue:`21998`)
+- Added :meth:`Interval.intersection` and :meth:`IntervalArray.intersection` to calculate the intersection between interval-like objects (:issue:`21998`)
+- Added :meth:`Interval.union` and :meth:`IntervalArray.union` to calculate the union between interval-like objects (:issue:`21998`)
 - Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
 - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
 
diff --git a/pandas/_libs/interval.pyx b/pandas/_libs/interval.pyx
index 564019d7c0d8c..944b71c4a3745 100644
--- a/pandas/_libs/interval.pyx
+++ b/pandas/_libs/interval.pyx
@@ -620,6 +620,277 @@ cdef class Interval(IntervalMixin):
         # (simplifying the negation allows this to be done in less operations)
         return op1(self.left, other.right) and op2(other.left, self.right)
 
+    def intersection(self, other):
+        """
+        Return the intersection of two intervals.
+
+        The intersection of two intervals is the common points shared between both,
+        including closed endpoints. Open endpoints are not included.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval to which to calculate the intersection.
+
+        Returns
+        -------
+        Interval or None
+            Interval containing the shared points and its closedness or None in
+            case there's no intersection.
+
+        Raises
+        ------
+        TypeError
+            If ``other`` is not an Interval.
+
+        See Also
+        --------
+        IntervalArray.intersection : The corresponding method for IntervalArray.
+
+        Examples
+        --------
+        >>> i0 = pd.Interval(0, 3, closed='right')
+        >>> i1 = pd.Interval(2, 4, closed='right')
+        >>> i0.intersection(i1)
+        Interval(2, 3, closed='right')
+
+        Each endpoint keeps the closedness of the interval it comes from:
+
+        >>> i0.intersection(pd.Interval(-1, 3, closed='both'))
+        Interval(0, 3, closed='right')
+
+        Intervals that have no intersection:
+
+        >>> i2 = pd.Interval(5, 8, closed='right')
+        >>> i0.intersection(i2) is None
+        True
+        """
+        if not isinstance(other, Interval):
+            raise TypeError("`other` must be an Interval, "
+                            f"got {type(other).__name__}")
+
+        # Left limit: the larger left endpoint, with the closedness of the
+        # interval it comes from (a shared endpoint must be closed in both)
+        if self.left < other.left:
+            ileft = other.left
+            lclosed = other.closed_left
+        elif self.left > other.left:
+            ileft = self.left
+            lclosed = self.closed_left
+        else:
+            ileft = self.left
+            lclosed = self.closed_left and other.closed_left
+
+        # Right limit: the smaller right endpoint, same closedness rule
+        if self.right < other.right:
+            iright = self.right
+            rclosed = self.closed_right
+        elif self.right > other.right:
+            iright = other.right
+            rclosed = other.closed_right
+        else:
+            iright = self.right
+            rclosed = self.closed_right and other.closed_right
+
+        # No intersection if there is no overlap; a single shared point only
+        # counts when it is closed on both sides
+        if iright < ileft or (iright == ileft and not (lclosed and rclosed)):
+            return None
+
+        if lclosed and rclosed:
+            closed = "both"
+        elif lclosed:
+            closed = "left"
+        elif rclosed:
+            closed = "right"
+        else:
+            closed = "neither"
+        return Interval(ileft, iright, closed=closed)
+
+    def union(self, other):
+        """
+        Return the union of two intervals.
+
+        The union of two intervals are all the values in both, including
+        closed endpoints.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval with which to create a union.
+
+        Returns
+        -------
+        numpy.ndarray
+            Object array with one interval if the two intervals overlap or
+            touch at an endpoint that is closed on at least one side, with the
+            two intervals (ordered by left endpoint) otherwise.
+
+        Raises
+        ------
+        TypeError
+            If ``other`` is not an Interval.
+
+        See Also
+        --------
+        IntervalArray.union : The corresponding method for IntervalArray.
+
+        Examples
+        --------
+        >>> i0 = pd.Interval(0, 3, closed='right')
+        >>> i1 = pd.Interval(2, 4, closed='right')
+        >>> i0.union(i1)
+        array([Interval(0, 4, closed='right')], dtype=object)
+
+        >>> i2 = pd.Interval(5, 8, closed='right')
+        >>> i0.union(i2)
+        array([Interval(0, 3, closed='right'), Interval(5, 8, closed='right')],
+              dtype=object)
+
+        >>> i3 = pd.Interval(3, 5, closed='right')
+        >>> i0.union(i3)
+        array([Interval(0, 5, closed='right')], dtype=object)
+        """
+        if not isinstance(other, Interval):
+            raise TypeError("`other` must be an Interval, "
+                            f"got {type(other).__name__}")
+
+        # Disjoint intervals are returned separately, unless they touch at an
+        # endpoint that is closed on at least one side: then no point is
+        # missing between them and they merge into a single interval
+        if not self.overlaps(other):
+            touching = (
+                self.left == other.right
+                and (self.closed_left or other.closed_right)
+            ) or (
+                self.right == other.left
+                and (self.closed_right or other.closed_left)
+            )
+            if not touching:
+                if self.left < other.left:
+                    return np.array([self, other], dtype=object)
+                return np.array([other, self], dtype=object)
+
+        # Left limit: the smaller left endpoint (closed if either is on a tie)
+        if self.left < other.left:
+            uleft = self.left
+            lclosed = self.closed_left
+        elif self.left > other.left:
+            uleft = other.left
+            lclosed = other.closed_left
+        else:
+            uleft = self.left
+            lclosed = self.closed_left or other.closed_left
+
+        # Right limit: the larger right endpoint (closed if either is on a tie)
+        if self.right > other.right:
+            uright = self.right
+            rclosed = self.closed_right
+        elif self.right < other.right:
+            uright = other.right
+            rclosed = other.closed_right
+        else:
+            uright = self.right
+            rclosed = self.closed_right or other.closed_right
+
+        if lclosed and rclosed:
+            closed = "both"
+        elif lclosed:
+            closed = "left"
+        elif rclosed:
+            closed = "right"
+        else:
+            closed = "neither"
+        return np.array([Interval(uleft, uright, closed=closed)], dtype=object)
+
+    def difference(self, other):
+        """
+        Return the difference between an interval and another.
+
+        The difference between two intervals are the points in the first
+        interval that are not shared with the second interval.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval to which to calculate the difference.
+
+        Returns
+        -------
+        numpy.ndarray
+            Object array with two intervals if the second interval is
+            contained within the first. Array with one interval if
+            the difference only shortens the limits of the interval.
+            Empty array if the first interval is contained in the second
+            and thus there are no points left after difference.
+
+        Raises
+        ------
+        TypeError
+            If ``other`` is not an Interval.
+
+        See Also
+        --------
+        IntervalArray.difference : The corresponding method for IntervalArray.
+
+        Examples
+        --------
+        >>> i0 = pd.Interval(0, 3, closed='right')
+        >>> i1 = pd.Interval(2, 4, closed='right')
+        >>> i0.difference(i1)
+        array([Interval(0, 2, closed='right')], dtype=object)
+
+        >>> i2 = pd.Interval(5, 8, closed='right')
+        >>> i0.difference(i2)
+        array([Interval(0, 3, closed='right')], dtype=object)
+
+        >>> i3 = pd.Interval(3, 5, closed='left')
+        >>> i0.difference(i3)
+        array([Interval(0, 3, closed='neither')], dtype=object)
+
+        >>> i4 = pd.Interval(-2, 7, closed='left')
+        >>> i0.difference(i4)
+        array([], dtype=object)
+
+        >>> i4.difference(i0)
+        array([Interval(-2, 0, closed='both'), Interval(3, 7, closed='neither')],
+              dtype=object)
+        """
+        if not isinstance(other, Interval):
+            raise TypeError("`other` must be an Interval, "
+                            f"got {type(other).__name__}")
+
+        # Only the points shared with ``other`` can be removed from ``self``
+        shared = self.intersection(other)
+        if shared is None:
+            return np.array([self], dtype=object)
+
+        pieces = []
+
+        # What is left of ``self`` before the shared region. ``shared`` lies
+        # inside ``self``, so if both start at the same point the only possible
+        # remainder is that point itself, closed in ``self`` but not removed
+        if self.left < shared.left:
+            if self.closed_left:
+                closed = "left" if shared.closed_left else "both"
+            else:
+                closed = "neither" if shared.closed_left else "right"
+            pieces.append(Interval(self.left, shared.left, closed=closed))
+        elif self.closed_left and not shared.closed_left:
+            pieces.append(Interval(self.left, self.left, closed="both"))
+
+        # What is left of ``self`` after the shared region (mirror image)
+        if shared.right < self.right:
+            if self.closed_right:
+                closed = "right" if shared.closed_right else "both"
+            else:
+                closed = "neither" if shared.closed_right else "left"
+            pieces.append(Interval(shared.right, self.right, closed=closed))
+        elif self.closed_right and not shared.closed_right:
+            pieces.append(Interval(self.right, self.right, closed="both"))
+
+        return np.array(pieces, dtype=object)
+
 
 @cython.wraparound(False)
 @cython.boundscheck(False)
diff --git
a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 52d64162358c8..c1ce94fd2baf1 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -1380,6 +1380,226 @@ def overlaps(self, other):
         # (simplifying the negation allows this to be done in less operations)
         return op1(self.left, other.right) & op2(other.left, self.right)
 
+    def _interval_setop(self, other, opname: str) -> np.ndarray:
+        """
+        Apply the scalar ``Interval`` method ``opname`` to every element.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval passed to the scalar method of each element.
+        opname : str
+            Name of the ``Interval`` method to call, e.g. ``"union"``.
+
+        Returns
+        -------
+        numpy.ndarray
+            1-D object array holding one result per element of ``self``.
+
+        Raises
+        ------
+        NotImplementedError
+            If ``other`` is an IntervalArray or IntervalIndex.
+        TypeError
+            If ``other`` is not an Interval.
+        """
+        if isinstance(other, (IntervalArray, ABCIntervalIndex)):
+            raise NotImplementedError
+        if not isinstance(other, Interval):
+            msg = f"`other` must be Interval-like, got {type(other).__name__}"
+            raise TypeError(msg)
+
+        # Fill a pre-allocated 1-D object array instead of calling np.array on
+        # the list of results: np.array would build a 2-D array whenever all
+        # per-element results happen to have the same length.
+        result = np.empty(len(self), dtype=object)
+        for i, interval in enumerate(self):
+            result[i] = getattr(interval, opname)(other)
+        return result
+
+    _interval_shared_docs["intersection"] = textwrap.dedent(
+        """
+        Calculates intersection between all intervals in the %(klass)s and a given
+        Interval.
+
+        The intersection of two intervals is the common points shared between both,
+        including closed endpoints. Open endpoints are not included.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval to which to calculate the intersection.
+
+        Returns
+        -------
+        numpy.ndarray
+            Object array containing the intersection between each interval and
+            other, or None where the two do not intersect.
+
+        See Also
+        --------
+        Interval.intersection : Calculate intersection between two Interval objects.
+
+        Examples
+        --------
+        %(examples)s
+
+        >>> intervals.intersection(pd.Interval(0, 2, 'right'))
+        array([Interval(0, 1, closed='right'), Interval(1, 2, closed='right'),
+               None], dtype=object)
+
+        Intersection consisting of a single point:
+
+        >>> intervals.intersection(pd.Interval(5, 8, closed='left'))
+        array([None, Interval(5, 5, closed='both'), None], dtype=object)
+        """
+    )
+
+    @Appender(
+        _interval_shared_docs["intersection"]
+        % {
+            "klass": "IntervalArray",
+            "examples": textwrap.dedent(
+                """\
+        >>> data = [(0, 1), (1, 5), (2, 4)]
+        >>> intervals = pd.arrays.IntervalArray.from_tuples(data)
+        >>> intervals
+        <IntervalArray>
+        [(0, 1], (1, 5], (2, 4]]
+        Length: 3, dtype: interval[int64, right]
+        """
+            ),
+        }
+    )
+    def intersection(self, other):
+        return self._interval_setop(other, "intersection")
+
+    _interval_shared_docs["union"] = textwrap.dedent(
+        """
+        Calculates union between each interval in the %(klass)s and a given
+        Interval.
+
+        The union of two intervals are all the values in both, including
+        closed endpoints.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval to which to calculate the union.
+
+        Returns
+        -------
+        numpy.ndarray
+            1-D object array with, for each interval, the array of intervals
+            that make up its union with other.
+
+        See Also
+        --------
+        Interval.union : Calculate union between two Interval objects.
+
+        Examples
+        --------
+        %(examples)s
+
+        >>> result = intervals.union(pd.Interval(5, 8, closed='left'))
+        >>> result.shape
+        (3,)
+
+        Intervals that overlap or touch other are merged with it:
+
+        >>> result[1]
+        array([Interval(1, 8, closed='neither')], dtype=object)
+
+        Disjoint intervals are kept separate:
+
+        >>> result[0]
+        array([Interval(0, 1, closed='right'), Interval(5, 8, closed='left')],
+              dtype=object)
+        """
+    )
+
+    @Appender(
+        _interval_shared_docs["union"]
+        % {
+            "klass": "IntervalArray",
+            "examples": textwrap.dedent(
+                """\
+        >>> data = [(0, 1), (1, 5), (2, 4)]
+        >>> intervals = pd.arrays.IntervalArray.from_tuples(data)
+        >>> intervals
+        <IntervalArray>
+        [(0, 1], (1, 5], (2, 4]]
+        Length: 3, dtype: interval[int64, right]
+        """
+            ),
+        }
+    )
+    def union(self, other):
+        return self._interval_setop(other, "union")
+
+    _interval_shared_docs["difference"] = textwrap.dedent(
+        """
+        Calculates difference between each Interval in the %(klass)s and a given
+        Interval.
+
+        The difference between two intervals are the points in the first
+        interval that are not shared with the second interval.
+
+        Parameters
+        ----------
+        other : Interval
+            Interval to which to calculate the difference.
+
+        Returns
+        -------
+        numpy.ndarray
+            1-D object array with, for each interval, the array of intervals
+            left after removing the points shared with other.
+
+        See Also
+        --------
+        Interval.difference : Calculate difference between two Interval objects.
+
+        Examples
+        --------
+        %(examples)s
+
+        >>> result = intervals.difference(pd.Interval(2, 3, closed='left'))
+        >>> result.shape
+        (3,)
+
+        Intervals that do not overlap other are returned unchanged:
+
+        >>> result[0]
+        array([Interval(0, 1, closed='right')], dtype=object)
+
+        An interval that contains other is split in two:
+
+        >>> result[1]
+        array([Interval(1, 2, closed='neither'), Interval(3, 5, closed='both')],
+              dtype=object)
+        """
+    )
+
+    @Appender(
+        _interval_shared_docs["difference"]
+        % {
+            "klass": "IntervalArray",
+            "examples": textwrap.dedent(
+                """\
+        >>> data = [(0, 1), (1, 5), (2, 4)]
+        >>> intervals = pd.arrays.IntervalArray.from_tuples(data)
+        >>> intervals
+        <IntervalArray>
+        [(0, 1], (1, 5], (2, 4]]
+        Length: 3, dtype: interval[int64, right]
+        """
+            ),
+        }
+    )
+    def difference(self, other):
+        return self._interval_setop(other, "difference")
+
     # ---------------------------------------------------------------------
 
     @property
diff --git a/pandas/tests/arrays/interval/test_overlaps.py b/pandas/tests/arrays/interval/test_overlaps.py
index 5a48cf024ec0d..fea04d75abea7 100644
--- a/pandas/tests/arrays/interval/test_overlaps.py
+++ b/pandas/tests/arrays/interval/test_overlaps.py
@@ -92,3 +92,132 @@ def
test_overlaps_invalid_type(self, constructor, other):
         msg = f"`other` must be Interval-like, got {type(other).__name__}"
         with pytest.raises(TypeError, match=msg):
             interval_container.overlaps(other)
+
+
+class TestIntersection:  # IntervalArray.intersection with a scalar Interval
+    def test_intersection_interval_array(self):
+        interval = Interval(1, 8, "left")
+
+        tuples = [  # Intervals:
+            (1, 8),  # identical
+            (2, 4),  # nested
+            (0, 9),  # spanning
+            (4, 10),  # partial
+            (-5, 1),  # adjacent closed
+            (8, 10),  # adjacent open
+            (10, 15),  # disjoint
+        ]
+        interval_container = IntervalArray.from_tuples(tuples, "both")
+
+        expected = np.array(  # None where nothing is shared with `interval`
+            [
+                Interval(1, 8, "left"),
+                Interval(2, 4, "both"),
+                Interval(1, 8, "left"),
+                Interval(4, 8, "left"),
+                Interval(1, 1, "both"),
+                None,
+                None,
+            ]
+        )
+        result = interval_container.intersection(interval)
+        tm.assert_numpy_array_equal(result, expected)
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_intersection_invalid_type(self, other):
+        interval_container = IntervalArray.from_breaks(range(5))
+        msg = f"`other` must be Interval-like, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval_container.intersection(other)
+
+
+class TestUnion:  # IntervalArray.union with a scalar Interval
+    def test_union_interval_array(self):
+        interval = Interval(1, 8, "left")
+
+        tuples = [  # Intervals:
+            (1, 8),  # identical
+            (2, 4),  # nested
+            (0, 9),  # spanning
+            (4, 10),  # partial
+            (-5, 1),  # adjacent closed
+            (8, 10),  # adjacent open
+            (10, 15),  # disjoint
+        ]
+        interval_container = IntervalArray.from_tuples(tuples, "both")
+
+        expected = np.array(  # one array per element: merged, or both kept
+            [
+                np.array([Interval(1, 8, "both")], dtype=object),
+                np.array([Interval(1, 8, "left")], dtype=object),
+                np.array([Interval(0, 9, "both")], dtype=object),
+                np.array([Interval(1, 10, "both")], dtype=object),
+                np.array([Interval(-5, 8, "left")], dtype=object),
+                np.array([Interval(1, 10, "both")], dtype=object),
+                np.array(
+                    [Interval(1, 8, "left"), Interval(10, 15, "both")], dtype=object
+                ),
+            ],
+            dtype=object,
+        )
+        result = interval_container.union(interval)
+        tm.assert_numpy_array_equal(result, expected)
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_union_invalid_type(self, other):
+        interval_container = IntervalArray.from_breaks(range(5))
+        msg = f"`other` must be Interval-like, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval_container.union(other)
+
+
+class TestDifference:  # IntervalArray.difference with a scalar Interval
+    def test_difference_interval_array(self):
+        interval = Interval(1, 8, "left")
+
+        tuples = [  # Intervals:
+            (1, 8),  # identical
+            (2, 4),  # nested
+            (0, 9),  # spanning
+            (4, 10),  # partial
+            (-5, 1),  # adjacent closed
+            (8, 10),  # adjacent open
+            (10, 15),  # disjoint
+        ]
+        interval_container = IntervalArray.from_tuples(tuples, "both")
+
+        expected = np.array(  # one array per element: what is left of it
+            [
+                np.array([Interval(8, 8, "both")], dtype=object),
+                np.array([], dtype=object),
+                np.array(
+                    [Interval(0, 1, "left"), Interval(8, 9, "both")], dtype=object
+                ),
+                np.array([Interval(8, 10, "both")], dtype=object),
+                np.array([Interval(-5, 1, "left")], dtype=object),
+                np.array([Interval(8, 10, "both")], dtype=object),
+                np.array([Interval(10, 15, "both")], dtype=object),
+            ],
+            dtype=object,
+        )
+        result = interval_container.difference(interval)
+        tm.assert_numpy_array_equal(result, expected)
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_difference_invalid_type(self, other):
+        interval_container = IntervalArray.from_breaks(range(5))
+        msg = f"`other` must be Interval-like, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval_container.difference(other)
diff --git a/pandas/tests/scalar/interval/test_overlaps.py b/pandas/tests/scalar/interval/test_overlaps.py
index
7fcf59d7bb4af..030ff03961443 100644
--- a/pandas/tests/scalar/interval/test_overlaps.py
+++ b/pandas/tests/scalar/interval/test_overlaps.py
@@ -1,3 +1,4 @@
+import numpy as np
 import pytest
 
 from pandas import (
@@ -5,6 +6,7 @@
     Timedelta,
     Timestamp,
 )
+import pandas._testing as tm
 
 
 @pytest.fixture(
@@ -65,3 +67,353 @@ def test_overlaps_invalid_type(self, other):
         msg = f"`other` must be an Interval, got {type(other).__name__}"
         with pytest.raises(TypeError, match=msg):
             interval.overlaps(other)
+
+
+class TestIntersection:  # Interval.intersection -> Interval, or None if disjoint
+    def test_intersection_self(self):
+        interval = Interval(1, 8, "left")
+        assert interval.intersection(interval) == interval
+
+    def test_intersection_include_limits(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(7, 9, "left"),  # include left
+                Interval(0, 2, "right"),  # include right
+                Interval(1, 8, "right"),  # open limit
+            ]
+        )
+
+        expected = np.array(  # (1, 8] with [1, 8): both ends end up open
+            [
+                Interval(7, 8, "left"),
+                Interval(1, 2, "both"),
+                Interval(1, 8, "neither"),
+            ]
+        )
+
+        result = np.array([interval.intersection(other) for interval in intervals])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_intersection_overlapping(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(2, 4, "both"),  # nested
+                Interval(0, 9, "both"),  # spanning
+                Interval(4, 10, "both"),  # partial
+            ]
+        )
+
+        expected = np.array(
+            [
+                Interval(2, 4, "both"),
+                Interval(1, 8, "left"),
+                Interval(4, 8, "left"),
+            ]
+        )
+
+        result = np.array([interval.intersection(other) for interval in intervals])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_intersection_adjacent(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(-5, 1, "both"),  # adjacent closed
+                Interval(8, 10, "both"),  # adjacent open
+                Interval(10, 15, "both"),  # disjoint
+            ]
+        )
+
+        expected = np.array(  # touching counts only if closed on both sides
+            [
+                Interval(1, 1, "both"),
+                None,
+                None,
+            ]
+        )
+
+        result = np.array([interval.intersection(other) for interval in intervals])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_intersection_timestamps(self):
+        year_2020 = Interval(
+            Timestamp("2020-01-01 00:00:00"),
+            Timestamp("2021-01-01 00:00:00"),
+            closed="left",
+        )
+
+        march_2020 = Interval(
+            Timestamp("2020-03-01 00:00:00"),
+            Timestamp("2020-04-01 00:00:00"),
+            closed="left",
+        )
+
+        result = year_2020.intersection(march_2020)
+        assert result == march_2020
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_intersection_invalid_type(self, other):
+        interval = Interval(0, 1)
+        msg = f"`other` must be an Interval, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval.intersection(other)
+
+
+class TestUnion:  # Interval.union -> object array of one or two Intervals
+    def test_union_self(self):
+        interval = Interval(1, 8, "left")
+
+        result = interval.union(interval)
+
+        expected = np.array([interval], dtype=object)
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_union_include_limits(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(7, 9, "left"),  # include left
+                Interval(0, 2, "right"),  # include right
+                Interval(1, 8, "right"),  # open limit
+            ]
+        )
+
+        expected = np.array(
+            [
+                np.array([Interval(1, 9, "left")], dtype=object),
+                np.array([Interval(0, 8, "neither")], dtype=object),
+                np.array([Interval(1, 8, "both")], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array([interval.union(other) for interval in intervals])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_union_overlapping(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(2, 4, "both"),  # nested
+                Interval(0, 9, "both"),  # spanning
+                Interval(4, 10, "both"),  # partial
+            ]
+        )
+
+        expected = np.array(
+            [
+                np.array([Interval(1, 8, "left")], dtype=object),
+                np.array([Interval(0, 9, "both")], dtype=object),
+                np.array([Interval(1, 10, "both")], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array([interval.union(other) for interval in intervals])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_union_adjacent(self):
+        other = Interval(1, 8, "left")
+
+        intervals = np.array(
+            [
+                Interval(-5, 1, "both"),  # adjacent closed
+                Interval(8, 10, "both"),  # adjacent open
+                Interval(10, 15, "both"),  # disjoint
+            ]
+        )
+
+        expected = np.array(  # touching merges if either side is closed there
+            [
+                np.array([Interval(-5, 8, "left")], dtype=object),
+                np.array([Interval(1, 10, "both")], dtype=object),
+                np.array([other, Interval(10, 15, "both")], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array(
+            [interval.union(other) for interval in intervals], dtype=object
+        )
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_union_timestamps(self):
+        year_2020 = Interval(
+            Timestamp("2020-01-01 00:00:00"),
+            Timestamp("2021-01-01 00:00:00"),
+            closed="left",
+        )
+
+        year_2021 = Interval(
+            Timestamp("2021-01-01 00:00:00"),
+            Timestamp("2022-01-01 00:00:00"),
+            closed="left",
+        )
+
+        expected = np.array(
+            [
+                Interval(
+                    Timestamp("2020-01-01 00:00:00"),
+                    Timestamp("2022-01-01 00:00:00"),
+                    closed="left",
+                )
+            ],
+            dtype=object,
+        )
+
+        result = year_2020.union(year_2021)
+        tm.assert_numpy_array_equal(result, expected)
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_union_invalid_type(self, other):
+        interval = Interval(0, 1)
+        msg = f"`other` must be an Interval, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval.union(other)
+
+
+class TestDifference:  # Interval.difference -> object array of 0-2 Intervals
+    def test_difference_self(self):
+        interval = Interval(1, 8, "left")
+
+        result = interval.difference(interval)
+
+        expected = np.array([], dtype=object)
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_difference_include_limits(self):
+        interval = Interval(1, 8, "left")
+
+        others = np.array(
+            [
+                Interval(7, 9, "left"),  # include right
+                Interval(0, 2, "right"),  # include left
+                Interval(1, 8, "right"),  # open limit
+            ]
+        )
+
+        expected = np.array(
+            [
+                np.array([Interval(1, 7, "left")], dtype=object),
+                np.array([Interval(2, 8, "neither")], dtype=object),
+                np.array([Interval(1, 1, "both")], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array([interval.difference(other) for other in others])
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_difference_overlapping(self):
+        interval = Interval(1, 8, "left")
+
+        others = np.array(
+            [
+                Interval(2, 4, "both"),  # nested
+                Interval(0, 9, "both"),  # spanning
+                Interval(4, 10, "both"),  # partial
+                Interval(0, 8, "both"),  # extends left
+                Interval(1, 9, "both"),  # extends right
+            ]
+        )
+
+        expected = np.array(  # empty array: `other` covers the whole interval
+            [
+                np.array(
+                    [Interval(1, 2, "left"), Interval(4, 8, "neither")], dtype=object
+                ),
+                np.array([], dtype=object),
+                np.array([Interval(1, 4, "left")], dtype=object),
+                np.array([], dtype=object),
+                np.array([], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array(
+            [interval.difference(other) for other in others], dtype=object
+        )
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_difference_adjacent(self):
+        interval = Interval(1, 8, "left")
+
+        others = np.array(
+            [
+                Interval(-5, 1, "both"),  # adjacent closed
+                Interval(8, 10, "both"),  # adjacent open
+                Interval(10, 15, "both"),  # disjoint
+            ]
+        )
+
+        expected = np.array(
+            [
+                np.array([Interval(1, 8, "neither")], dtype=object),
+                np.array([interval], dtype=object),
+                np.array([interval], dtype=object),
+            ],
+            dtype=object,
+        )
+
+        result = np.array(
+            [interval.difference(other) for other in others], dtype=object
+        )
+        tm.assert_numpy_array_equal(result, expected)
+
+    def test_difference_timestamps(self):
+        year_2020 = Interval(
+            Timestamp("2020-01-01 00:00:00"),
+            Timestamp("2021-01-01 00:00:00"),
+            closed="left",
+        )
+
+        march_2020 = Interval(
+            Timestamp("2020-03-01 00:00:00"),
+            Timestamp("2020-04-01 00:00:00"),
+            closed="left",
+        )
+
+        expected = np.array(
+            [
+                Interval(
+                    Timestamp("2020-01-01 00:00:00"),
+                    Timestamp("2020-03-01 00:00:00"),
+                    closed="left",
+                ),
+                Interval(
+                    Timestamp("2020-04-01 00:00:00"),
+                    Timestamp("2021-01-01 00:00:00"),
+                    closed="left",
+                ),
+            ],
+            dtype=object,
+        )
+
+        result = year_2020.difference(march_2020)
+        tm.assert_numpy_array_equal(result, expected)
+
+    @pytest.mark.parametrize(
+        "other",
+        [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
+        ids=lambda x: type(x).__name__,
+    )
+    def test_difference_invalid_type(self, other):
+        interval = Interval(0, 1)
+        msg = f"`other` must be an Interval, got {type(other).__name__}"
+        with pytest.raises(TypeError, match=msg):
+            interval.difference(other)