@@ -81,6 +81,41 @@ def test_compare_len1_raises(self):
81
81
with pytest .raises (ValueError , match = "Lengths must match" ):
82
82
idx <= idx [[0 ]]
83
83
84
+ @pytest .mark .parametrize ("reverse" , [True , False ])
85
+ @pytest .mark .parametrize ("as_index" , [True , False ])
86
+ def test_compare_categorical_dtype (self , arr1d , as_index , reverse , ordered ):
87
+ other = pd .Categorical (arr1d , ordered = ordered )
88
+ if as_index :
89
+ other = pd .CategoricalIndex (other )
90
+
91
+ left , right = arr1d , other
92
+ if reverse :
93
+ left , right = right , left
94
+
95
+ ones = np .ones (arr1d .shape , dtype = bool )
96
+ zeros = ~ ones
97
+
98
+ result = left == right
99
+ tm .assert_numpy_array_equal (result , ones )
100
+
101
+ result = left != right
102
+ tm .assert_numpy_array_equal (result , zeros )
103
+
104
+ if not reverse and not as_index :
105
+ # Otherwise Categorical raises TypeError bc it is not ordered
106
+ # TODO: we should probably get the same behavior regardless?
107
+ result = left < right
108
+ tm .assert_numpy_array_equal (result , zeros )
109
+
110
+ result = left <= right
111
+ tm .assert_numpy_array_equal (result , ones )
112
+
113
+ result = left > right
114
+ tm .assert_numpy_array_equal (result , zeros )
115
+
116
+ result = left >= right
117
+ tm .assert_numpy_array_equal (result , ones )
118
+
84
119
def test_take (self ):
85
120
data = np .arange (100 , dtype = "i8" ) * 24 * 3600 * 10 ** 9
86
121
np .random .shuffle (data )
@@ -251,6 +286,20 @@ def test_setitem_str_array(self, arr1d):
251
286
252
287
tm .assert_equal (arr1d , expected )
253
288
289
+ @pytest .mark .parametrize ("as_index" , [True , False ])
290
+ def test_setitem_categorical (self , arr1d , as_index ):
291
+ expected = arr1d .copy ()[::- 1 ]
292
+ if not isinstance (expected , PeriodArray ):
293
+ expected = expected ._with_freq (None )
294
+
295
+ cat = pd .Categorical (arr1d )
296
+ if as_index :
297
+ cat = pd .CategoricalIndex (cat )
298
+
299
+ arr1d [:] = cat [::- 1 ]
300
+
301
+ tm .assert_equal (arr1d , expected )
302
+
254
303
def test_setitem_raises (self ):
255
304
data = np .arange (10 , dtype = "i8" ) * 24 * 3600 * 10 ** 9
256
305
arr = self .array_cls (data , freq = "D" )
@@ -924,6 +973,7 @@ def test_to_numpy_extra(array):
924
973
tm .assert_equal (array , original )
925
974
926
975
976
+ @pytest .mark .parametrize ("as_index" , [True , False ])
927
977
@pytest .mark .parametrize (
928
978
"values" ,
929
979
[
@@ -932,9 +982,23 @@ def test_to_numpy_extra(array):
932
982
pd .PeriodIndex (["2020-01-01" , "2020-02-01" ], freq = "D" ),
933
983
],
934
984
)
935
- @pytest .mark .parametrize ("klass" , [list , np .array , pd .array , pd .Series ])
936
- def test_searchsorted_datetimelike_with_listlike (values , klass ):
985
+ @pytest .mark .parametrize (
986
+ "klass" ,
987
+ [
988
+ list ,
989
+ np .array ,
990
+ pd .array ,
991
+ pd .Series ,
992
+ pd .Index ,
993
+ pd .Categorical ,
994
+ pd .CategoricalIndex ,
995
+ ],
996
+ )
997
+ def test_searchsorted_datetimelike_with_listlike (values , klass , as_index ):
937
998
# https://github.com/pandas-dev/pandas/issues/32762
999
+ if not as_index :
1000
+ values = values ._data
1001
+
938
1002
result = values .searchsorted (klass (values ))
939
1003
expected = np .array ([0 , 1 ], dtype = result .dtype )
940
1004
0 commit comments