1
+ import operator
2
+
1
3
import numpy as np
2
4
import pytest
3
5
6
+ from pandas .core .dtypes .common import is_list_like
7
+
4
8
import pandas as pd
5
9
from pandas import (
10
+ Categorical ,
6
11
Index ,
7
12
Interval ,
8
13
IntervalIndex ,
14
+ Period ,
15
+ Series ,
9
16
Timedelta ,
10
17
Timestamp ,
11
18
date_range ,
19
+ period_range ,
12
20
timedelta_range ,
13
21
)
14
22
import pandas ._testing as tm
@@ -35,6 +43,18 @@ def left_right_dtypes(request):
35
43
return request .param
36
44
37
45
46
+ def create_categorical_intervals (left , right , closed = "right" ):
47
+ return Categorical (IntervalIndex .from_arrays (left , right , closed ))
48
+
49
+
50
+ def create_series_intervals (left , right , closed = "right" ):
51
+ return Series (IntervalArray .from_arrays (left , right , closed ))
52
+
53
+
54
+ def create_series_categorical_intervals (left , right , closed = "right" ):
55
+ return Series (Categorical (IntervalIndex .from_arrays (left , right , closed )))
56
+
57
+
38
58
class TestAttributes :
39
59
@pytest .mark .parametrize (
40
60
"left, right" ,
@@ -93,6 +113,221 @@ def test_set_na(self, left_right_dtypes):
93
113
tm .assert_extension_array_equal (result , expected )
94
114
95
115
116
+ class TestComparison :
117
+ @pytest .fixture (params = [operator .eq , operator .ne ])
118
+ def op (self , request ):
119
+ return request .param
120
+
121
+ @pytest .fixture
122
+ def array (self , left_right_dtypes ):
123
+ """
124
+ Fixture to generate an IntervalArray of various dtypes containing NA if possible
125
+ """
126
+ left , right = left_right_dtypes
127
+ if left .dtype != "int64" :
128
+ left , right = left .insert (4 , np .nan ), right .insert (4 , np .nan )
129
+ else :
130
+ left , right = left .insert (4 , 10 ), right .insert (4 , 20 )
131
+ return IntervalArray .from_arrays (left , right )
132
+
133
+ @pytest .fixture (
134
+ params = [
135
+ IntervalArray .from_arrays ,
136
+ IntervalIndex .from_arrays ,
137
+ create_categorical_intervals ,
138
+ create_series_intervals ,
139
+ create_series_categorical_intervals ,
140
+ ],
141
+ ids = [
142
+ "IntervalArray" ,
143
+ "IntervalIndex" ,
144
+ "Categorical[Interval]" ,
145
+ "Series[Interval]" ,
146
+ "Series[Categorical[Interval]]" ,
147
+ ],
148
+ )
149
+ def interval_constructor (self , request ):
150
+ """
151
+ Fixture for all pandas native interval constructors.
152
+ To be used as the LHS of IntervalArray comparisons.
153
+ """
154
+ return request .param
155
+
156
+ def elementwise_comparison (self , op , array , other ):
157
+ """
158
+ Helper that performs elementwise comparisions between `array` and `other`
159
+ """
160
+ other = other if is_list_like (other ) else [other ] * len (array )
161
+ return np .array ([op (x , y ) for x , y in zip (array , other )])
162
+
163
+ def test_compare_scalar_interval (self , op , array ):
164
+ # matches first interval
165
+ other = array [0 ]
166
+ result = op (array , other )
167
+ expected = self .elementwise_comparison (op , array , other )
168
+ tm .assert_numpy_array_equal (result , expected )
169
+
170
+ # matches on a single endpoint but not both
171
+ other = Interval (array .left [0 ], array .right [1 ])
172
+ result = op (array , other )
173
+ expected = self .elementwise_comparison (op , array , other )
174
+ tm .assert_numpy_array_equal (result , expected )
175
+
176
+ def test_compare_scalar_interval_mixed_closed (self , op , closed , other_closed ):
177
+ array = IntervalArray .from_arrays (range (2 ), range (1 , 3 ), closed = closed )
178
+ other = Interval (0 , 1 , closed = other_closed )
179
+
180
+ result = op (array , other )
181
+ expected = self .elementwise_comparison (op , array , other )
182
+ tm .assert_numpy_array_equal (result , expected )
183
+
184
+ def test_compare_scalar_na (self , op , array , nulls_fixture ):
185
+ result = op (array , nulls_fixture )
186
+ expected = self .elementwise_comparison (op , array , nulls_fixture )
187
+ tm .assert_numpy_array_equal (result , expected )
188
+
189
+ @pytest .mark .parametrize (
190
+ "other" ,
191
+ [
192
+ 0 ,
193
+ 1.0 ,
194
+ True ,
195
+ "foo" ,
196
+ Timestamp ("2017-01-01" ),
197
+ Timestamp ("2017-01-01" , tz = "US/Eastern" ),
198
+ Timedelta ("0 days" ),
199
+ Period ("2017-01-01" , "D" ),
200
+ ],
201
+ )
202
+ def test_compare_scalar_other (self , op , array , other ):
203
+ result = op (array , other )
204
+ expected = self .elementwise_comparison (op , array , other )
205
+ tm .assert_numpy_array_equal (result , expected )
206
+
207
+ def test_compare_list_like_interval (
208
+ self , op , array , interval_constructor ,
209
+ ):
210
+ # same endpoints
211
+ other = interval_constructor (array .left , array .right )
212
+ result = op (array , other )
213
+ expected = self .elementwise_comparison (op , array , other )
214
+ tm .assert_numpy_array_equal (result , expected )
215
+
216
+ # different endpoints
217
+ other = interval_constructor (array .left [::- 1 ], array .right [::- 1 ])
218
+ result = op (array , other )
219
+ expected = self .elementwise_comparison (op , array , other )
220
+ tm .assert_numpy_array_equal (result , expected )
221
+
222
+ # all nan endpoints
223
+ other = interval_constructor ([np .nan ] * 4 , [np .nan ] * 4 )
224
+ result = op (array , other )
225
+ expected = self .elementwise_comparison (op , array , other )
226
+ tm .assert_numpy_array_equal (result , expected )
227
+
228
+ def test_compare_list_like_interval_mixed_closed (
229
+ self , op , interval_constructor , closed , other_closed
230
+ ):
231
+ array = IntervalArray .from_arrays (range (2 ), range (1 , 3 ), closed = closed )
232
+ other = interval_constructor (range (2 ), range (1 , 3 ), closed = other_closed )
233
+
234
+ result = op (array , other )
235
+ expected = self .elementwise_comparison (op , array , other )
236
+ tm .assert_numpy_array_equal (result , expected )
237
+
238
+ @pytest .mark .parametrize (
239
+ "other" ,
240
+ [
241
+ (
242
+ Interval (0 , 1 ),
243
+ Interval (Timedelta ("1 day" ), Timedelta ("2 days" )),
244
+ Interval (4 , 5 , "both" ),
245
+ Interval (10 , 20 , "neither" ),
246
+ ),
247
+ (0 , 1.5 , Timestamp ("20170103" ), np .nan ),
248
+ (
249
+ Timestamp ("20170102" , tz = "US/Eastern" ),
250
+ Timedelta ("2 days" ),
251
+ "baz" ,
252
+ pd .NaT ,
253
+ ),
254
+ ],
255
+ )
256
+ def test_compare_list_like_object (self , op , array , other ):
257
+ result = op (array , other )
258
+ expected = self .elementwise_comparison (op , array , other )
259
+ tm .assert_numpy_array_equal (result , expected )
260
+
261
+ def test_compare_list_like_nan (self , op , array , nulls_fixture ):
262
+ other = [nulls_fixture ] * 4
263
+ result = op (array , other )
264
+ expected = self .elementwise_comparison (op , array , other )
265
+ tm .assert_numpy_array_equal (result , expected )
266
+
267
+ @pytest .mark .parametrize (
268
+ "other" ,
269
+ [
270
+ np .arange (4 , dtype = "int64" ),
271
+ np .arange (4 , dtype = "float64" ),
272
+ date_range ("2017-01-01" , periods = 4 ),
273
+ date_range ("2017-01-01" , periods = 4 , tz = "US/Eastern" ),
274
+ timedelta_range ("0 days" , periods = 4 ),
275
+ period_range ("2017-01-01" , periods = 4 , freq = "D" ),
276
+ Categorical (list ("abab" )),
277
+ Categorical (date_range ("2017-01-01" , periods = 4 )),
278
+ pd .array (list ("abcd" )),
279
+ pd .array (["foo" , 3.14 , None , object ()]),
280
+ ],
281
+ ids = lambda x : str (x .dtype ),
282
+ )
283
+ def test_compare_list_like_other (self , op , array , other ):
284
+ result = op (array , other )
285
+ expected = self .elementwise_comparison (op , array , other )
286
+ tm .assert_numpy_array_equal (result , expected )
287
+
288
+ @pytest .mark .parametrize ("length" , [1 , 3 , 5 ])
289
+ @pytest .mark .parametrize ("other_constructor" , [IntervalArray , list ])
290
+ def test_compare_length_mismatch_errors (self , op , other_constructor , length ):
291
+ array = IntervalArray .from_arrays (range (4 ), range (1 , 5 ))
292
+ other = other_constructor ([Interval (0 , 1 )] * length )
293
+ with pytest .raises (ValueError , match = "Lengths must match to compare" ):
294
+ op (array , other )
295
+
296
+ @pytest .mark .parametrize (
297
+ "constructor, expected_type, assert_func" ,
298
+ [
299
+ (IntervalIndex , np .array , tm .assert_numpy_array_equal ),
300
+ (Series , Series , tm .assert_series_equal ),
301
+ ],
302
+ )
303
+ def test_index_series_compat (self , op , constructor , expected_type , assert_func ):
304
+ # IntervalIndex/Series that rely on IntervalArray for comparisons
305
+ breaks = range (4 )
306
+ index = constructor (IntervalIndex .from_breaks (breaks ))
307
+
308
+ # scalar comparisons
309
+ other = index [0 ]
310
+ result = op (index , other )
311
+ expected = expected_type (self .elementwise_comparison (op , index , other ))
312
+ assert_func (result , expected )
313
+
314
+ other = breaks [0 ]
315
+ result = op (index , other )
316
+ expected = expected_type (self .elementwise_comparison (op , index , other ))
317
+ assert_func (result , expected )
318
+
319
+ # list-like comparisons
320
+ other = IntervalArray .from_breaks (breaks )
321
+ result = op (index , other )
322
+ expected = expected_type (self .elementwise_comparison (op , index , other ))
323
+ assert_func (result , expected )
324
+
325
+ other = [index [0 ], breaks [0 ], "foo" ]
326
+ result = op (index , other )
327
+ expected = expected_type (self .elementwise_comparison (op , index , other ))
328
+ assert_func (result , expected )
329
+
330
+
96
331
def test_repr ():
97
332
# GH 25022
98
333
arr = IntervalArray .from_tuples ([(0 , 1 ), (1 , 2 )])
0 commit comments