|
| 1 | +import numpy as np |
1 | 2 | import pytest
|
2 | 3 |
|
3 | 4 | from pandas import (
|
4 | 5 | Interval,
|
5 | 6 | Timedelta,
|
6 | 7 | Timestamp,
|
7 | 8 | )
|
| 9 | +import pandas._testing as tm |
8 | 10 |
|
9 | 11 |
|
10 | 12 | @pytest.fixture(
|
@@ -65,3 +67,102 @@ def test_overlaps_invalid_type(self, other):
|
65 | 67 | msg = f"`other` must be an Interval, got {type(other).__name__}"
|
66 | 68 | with pytest.raises(TypeError, match=msg):
|
67 | 69 | interval.overlaps(other)
|
| 70 | + |
| 71 | + |
| 72 | +class TestIntersection: |
| 73 | + def test_intersection_self(self): |
| 74 | + interval = Interval(1, 8, "left") |
| 75 | + assert interval.intersection(interval) == interval |
| 76 | + |
| 77 | + def test_intersection_include_limits(self): |
| 78 | + other = Interval(1, 8, "left") |
| 79 | + |
| 80 | + intervals = np.array( |
| 81 | + [ |
| 82 | + Interval(7, 9, "left"), # include left |
| 83 | + Interval(0, 2, "right"), # include right |
| 84 | + Interval(1, 8, "right"), # open limit |
| 85 | + ] |
| 86 | + ) |
| 87 | + |
| 88 | + expected = np.array( |
| 89 | + [ |
| 90 | + Interval(7, 8, "left"), |
| 91 | + Interval(1, 2, "both"), |
| 92 | + Interval(1, 8, "neither"), |
| 93 | + ] |
| 94 | + ) |
| 95 | + |
| 96 | + result = np.array([interval.intersection(other) for interval in intervals]) |
| 97 | + tm.assert_numpy_array_equal(result, expected) |
| 98 | + |
| 99 | + def test_intersection_overlapping(self): |
| 100 | + other = Interval(1, 8, "left") |
| 101 | + |
| 102 | + intervals = np.array( |
| 103 | + [ |
| 104 | + Interval(2, 4, "both"), # nested |
| 105 | + Interval(0, 9, "both"), # spanning |
| 106 | + Interval(4, 10, "both"), # partial |
| 107 | + ] |
| 108 | + ) |
| 109 | + |
| 110 | + expected = np.array( |
| 111 | + [ |
| 112 | + Interval(2, 4, "both"), |
| 113 | + Interval(1, 8, "left"), |
| 114 | + Interval(4, 8, "left"), |
| 115 | + ] |
| 116 | + ) |
| 117 | + |
| 118 | + result = np.array([interval.intersection(other) for interval in intervals]) |
| 119 | + tm.assert_numpy_array_equal(result, expected) |
| 120 | + |
| 121 | + def test_intersection_adjacent(self): |
| 122 | + other = Interval(1, 8, "left") |
| 123 | + |
| 124 | + intervals = np.array( |
| 125 | + [ |
| 126 | + Interval(-5, 1, "both"), # adjacent closed |
| 127 | + Interval(8, 10, "both"), # adjacent open |
| 128 | + Interval(10, 15, "both"), # disjoint |
| 129 | + ] |
| 130 | + ) |
| 131 | + |
| 132 | + expected = np.array( |
| 133 | + [ |
| 134 | + Interval(1, 1, "both"), |
| 135 | + None, |
| 136 | + None, |
| 137 | + ] |
| 138 | + ) |
| 139 | + |
| 140 | + result = np.array([interval.intersection(other) for interval in intervals]) |
| 141 | + tm.assert_numpy_array_equal(result, expected) |
| 142 | + |
| 143 | + def test_intersection_timestamps(self): |
| 144 | + year_2020 = Interval( |
| 145 | + Timestamp("2020-01-01 00:00:00"), |
| 146 | + Timestamp("2021-01-01 00:00:00"), |
| 147 | + closed="left", |
| 148 | + ) |
| 149 | + |
| 150 | + march_2020 = Interval( |
| 151 | + Timestamp("2020-03-01 00:00:00"), |
| 152 | + Timestamp("2020-04-01 00:00:00"), |
| 153 | + closed="left", |
| 154 | + ) |
| 155 | + |
| 156 | + result = year_2020.intersection(march_2020) |
| 157 | + assert result == march_2020 |
| 158 | + |
| 159 | + @pytest.mark.parametrize( |
| 160 | + "other", |
| 161 | + [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")], |
| 162 | + ids=lambda x: type(x).__name__, |
| 163 | + ) |
| 164 | + def test_intersection_invalid_type(self, other): |
| 165 | + interval = Interval(0, 1) |
| 166 | + msg = f"`other` must be an Interval, got {type(other).__name__}" |
| 167 | + with pytest.raises(TypeError, match=msg): |
| 168 | + interval.intersection(other) |
0 commit comments