Skip to content

Commit 3ed7dff

Browse files
authored
BUG: datetimelike/categorical comparisons, standardize behavior (#34055)
1 parent fa3662c commit 3ed7dff

File tree

3 files changed

+86
-20
lines changed

3 files changed

+86
-20
lines changed

pandas/core/arrays/categorical.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@
4545
from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d
4646
from pandas.core.array_algos.transforms import shift
4747
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
48-
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
48+
from pandas.core.base import (
49+
ExtensionArray,
50+
NoNewAttributesMixin,
51+
PandasObject,
52+
_shared_docs,
53+
)
4954
import pandas.core.common as com
5055
from pandas.core.construction import array, extract_array, sanitize_array
5156
from pandas.core.indexers import check_array_indexer, deprecate_ndim_indexing
@@ -124,17 +129,20 @@ def func(self, other):
124129
"scalar, which is not a category."
125130
)
126131
else:
127-
128132
# allow categorical vs object dtype array comparisons for equality
129133
# these are only positional comparisons
130-
if opname in ["__eq__", "__ne__"]:
131-
return getattr(np.array(self), opname)(np.array(other))
134+
if opname not in ["__eq__", "__ne__"]:
135+
raise TypeError(
136+
f"Cannot compare a Categorical for op {opname} with "
137+
f"type {type(other)}.\nIf you want to compare values, "
138+
"use 'np.asarray(cat) <op> other'."
139+
)
132140

133-
raise TypeError(
134-
f"Cannot compare a Categorical for op {opname} with "
135-
f"type {type(other)}.\nIf you want to compare values, "
136-
"use 'np.asarray(cat) <op> other'."
137-
)
141+
if isinstance(other, ExtensionArray) and needs_i8_conversion(other):
142+
# We would return NotImplemented here, but that messes up
143+
# ExtensionIndex's wrapped methods
144+
return op(other, self)
145+
return getattr(np.array(self), opname)(np.array(other))
138146

139147
func.__name__ = opname
140148

pandas/core/arrays/datetimelike.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def _validate_comparison_value(self, other):
9898

9999
@unpack_zerodim_and_defer(opname)
100100
def wrapper(self, other):
101-
102101
try:
103102
other = _validate_comparison_value(self, other)
104103
except InvalidComparison:
@@ -762,12 +761,7 @@ def _validate_shift_value(self, fill_value):
762761
return self._unbox(fill_value)
763762

764763
def _validate_listlike(
765-
self,
766-
value,
767-
opname: str,
768-
cast_str: bool = False,
769-
cast_cat: bool = False,
770-
allow_object: bool = False,
764+
self, value, opname: str, cast_str: bool = False, allow_object: bool = False,
771765
):
772766
if isinstance(value, type(self)):
773767
return value
@@ -786,7 +780,7 @@ def _validate_listlike(
786780
except ValueError:
787781
pass
788782

789-
if cast_cat and is_categorical_dtype(value.dtype):
783+
if is_categorical_dtype(value.dtype):
790784
# e.g. we have a Categorical holding self.dtype
791785
if is_dtype_equal(value.categories.dtype, self.dtype):
792786
# TODO: do we need equal dtype or just comparable?
@@ -871,7 +865,7 @@ def _validate_where_value(self, other):
871865
raise TypeError(f"Where requires matching dtype, not {type(other)}")
872866

873867
else:
874-
other = self._validate_listlike(other, "where", cast_cat=True)
868+
other = self._validate_listlike(other, "where")
875869
self._check_compatible_with(other, setitem=True)
876870

877871
self._check_compatible_with(other, setitem=True)

pandas/tests/arrays/test_datetimelike.py

+66-2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,41 @@ def test_compare_len1_raises(self):
8181
with pytest.raises(ValueError, match="Lengths must match"):
8282
idx <= idx[[0]]
8383

84+
@pytest.mark.parametrize("reverse", [True, False])
85+
@pytest.mark.parametrize("as_index", [True, False])
86+
def test_compare_categorical_dtype(self, arr1d, as_index, reverse, ordered):
87+
other = pd.Categorical(arr1d, ordered=ordered)
88+
if as_index:
89+
other = pd.CategoricalIndex(other)
90+
91+
left, right = arr1d, other
92+
if reverse:
93+
left, right = right, left
94+
95+
ones = np.ones(arr1d.shape, dtype=bool)
96+
zeros = ~ones
97+
98+
result = left == right
99+
tm.assert_numpy_array_equal(result, ones)
100+
101+
result = left != right
102+
tm.assert_numpy_array_equal(result, zeros)
103+
104+
if not reverse and not as_index:
105+
# Otherwise Categorical raises TypeError bc it is not ordered
106+
# TODO: we should probably get the same behavior regardless?
107+
result = left < right
108+
tm.assert_numpy_array_equal(result, zeros)
109+
110+
result = left <= right
111+
tm.assert_numpy_array_equal(result, ones)
112+
113+
result = left > right
114+
tm.assert_numpy_array_equal(result, zeros)
115+
116+
result = left >= right
117+
tm.assert_numpy_array_equal(result, ones)
118+
84119
def test_take(self):
85120
data = np.arange(100, dtype="i8") * 24 * 3600 * 10 ** 9
86121
np.random.shuffle(data)
@@ -251,6 +286,20 @@ def test_setitem_str_array(self, arr1d):
251286

252287
tm.assert_equal(arr1d, expected)
253288

289+
@pytest.mark.parametrize("as_index", [True, False])
290+
def test_setitem_categorical(self, arr1d, as_index):
291+
expected = arr1d.copy()[::-1]
292+
if not isinstance(expected, PeriodArray):
293+
expected = expected._with_freq(None)
294+
295+
cat = pd.Categorical(arr1d)
296+
if as_index:
297+
cat = pd.CategoricalIndex(cat)
298+
299+
arr1d[:] = cat[::-1]
300+
301+
tm.assert_equal(arr1d, expected)
302+
254303
def test_setitem_raises(self):
255304
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
256305
arr = self.array_cls(data, freq="D")
@@ -924,6 +973,7 @@ def test_to_numpy_extra(array):
924973
tm.assert_equal(array, original)
925974

926975

976+
@pytest.mark.parametrize("as_index", [True, False])
927977
@pytest.mark.parametrize(
928978
"values",
929979
[
@@ -932,9 +982,23 @@ def test_to_numpy_extra(array):
932982
pd.PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"),
933983
],
934984
)
935-
@pytest.mark.parametrize("klass", [list, np.array, pd.array, pd.Series])
936-
def test_searchsorted_datetimelike_with_listlike(values, klass):
985+
@pytest.mark.parametrize(
986+
"klass",
987+
[
988+
list,
989+
np.array,
990+
pd.array,
991+
pd.Series,
992+
pd.Index,
993+
pd.Categorical,
994+
pd.CategoricalIndex,
995+
],
996+
)
997+
def test_searchsorted_datetimelike_with_listlike(values, klass, as_index):
937998
# https://github.com/pandas-dev/pandas/issues/32762
999+
if not as_index:
1000+
values = values._data
1001+
9381002
result = values.searchsorted(klass(values))
9391003
expected = np.array([0, 1], dtype=result.dtype)
9401004

0 commit comments

Comments
 (0)