diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index efb53e5a89314..59ae2dc171cf8 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -45,7 +45,12 @@ from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d from pandas.core.array_algos.transforms import shift from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray -from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs +from pandas.core.base import ( + ExtensionArray, + NoNewAttributesMixin, + PandasObject, + _shared_docs, +) import pandas.core.common as com from pandas.core.construction import array, extract_array, sanitize_array from pandas.core.indexers import check_array_indexer, deprecate_ndim_indexing @@ -124,17 +129,20 @@ def func(self, other): "scalar, which is not a category." ) else: - # allow categorical vs object dtype array comparisons for equality # these are only positional comparisons - if opname in ["__eq__", "__ne__"]: - return getattr(np.array(self), opname)(np.array(other)) + if opname not in ["__eq__", "__ne__"]: + raise TypeError( + f"Cannot compare a Categorical for op {opname} with " + f"type {type(other)}.\nIf you want to compare values, " + "use 'np.asarray(cat) other'." + ) - raise TypeError( - f"Cannot compare a Categorical for op {opname} with " - f"type {type(other)}.\nIf you want to compare values, " - "use 'np.asarray(cat) other'." - ) + if isinstance(other, ExtensionArray) and needs_i8_conversion(other): + # We would return NotImplemented here, but that messes up + # ExtensionIndex's wrapped methods + return op(other, self) + return getattr(np.array(self), opname)(np.array(other)) func.__name__ = opname diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 3aeec70ab72d7..0eec46b3d95f7 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -95,7 +95,6 @@ def _validate_comparison_value(self, other): @unpack_zerodim_and_defer(opname) def wrapper(self, other): - try: other = _validate_comparison_value(self, other) except InvalidComparison: @@ -759,12 +758,7 @@ def _validate_shift_value(self, fill_value): return self._unbox(fill_value) def _validate_listlike( - self, - value, - opname: str, - cast_str: bool = False, - cast_cat: bool = False, - allow_object: bool = False, + self, value, opname: str, cast_str: bool = False, allow_object: bool = False, ): if isinstance(value, type(self)): return value @@ -783,7 +777,7 @@ def _validate_listlike( except ValueError: pass - if cast_cat and is_categorical_dtype(value.dtype): + if is_categorical_dtype(value.dtype): # e.g. we have a Categorical holding self.dtype if is_dtype_equal(value.categories.dtype, self.dtype): # TODO: do we need equal dtype or just comparable? @@ -868,7 +862,7 @@ def _validate_where_value(self, other): raise TypeError(f"Where requires matching dtype, not {type(other)}") else: - other = self._validate_listlike(other, "where", cast_cat=True) + other = self._validate_listlike(other, "where") self._check_compatible_with(other, setitem=True) self._check_compatible_with(other, setitem=True) diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 71726c745df0a..61d78034f0747 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -81,6 +81,41 @@ def test_compare_len1_raises(self): with pytest.raises(ValueError, match="Lengths must match"): idx <= idx[[0]] + @pytest.mark.parametrize("reverse", [True, False]) + @pytest.mark.parametrize("as_index", [True, False]) + def test_compare_categorical_dtype(self, arr1d, as_index, reverse, ordered): + other = pd.Categorical(arr1d, ordered=ordered) + if as_index: + other = pd.CategoricalIndex(other) + + left, right = arr1d, other + if reverse: + left, right = right, left + + ones = np.ones(arr1d.shape, dtype=bool) + zeros = ~ones + + result = left == right + tm.assert_numpy_array_equal(result, ones) + + result = left != right + tm.assert_numpy_array_equal(result, zeros) + + if not reverse and not as_index: + # Otherwise Categorical raises TypeError bc it is not ordered + # TODO: we should probably get the same behavior regardless? + result = left < right + tm.assert_numpy_array_equal(result, zeros) + + result = left <= right + tm.assert_numpy_array_equal(result, ones) + + result = left > right + tm.assert_numpy_array_equal(result, zeros) + + result = left >= right + tm.assert_numpy_array_equal(result, ones) + def test_take(self): data = np.arange(100, dtype="i8") * 24 * 3600 * 10 ** 9 np.random.shuffle(data) @@ -251,6 +286,20 @@ def test_setitem_str_array(self, arr1d): tm.assert_equal(arr1d, expected) + @pytest.mark.parametrize("as_index", [True, False]) + def test_setitem_categorical(self, arr1d, as_index): + expected = arr1d.copy()[::-1] + if not isinstance(expected, PeriodArray): + expected = expected._with_freq(None) + + cat = pd.Categorical(arr1d) + if as_index: + cat = pd.CategoricalIndex(cat) + + arr1d[:] = cat[::-1] + + tm.assert_equal(arr1d, expected) + def test_setitem_raises(self): data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9 arr = self.array_cls(data, freq="D") @@ -924,6 +973,7 @@ def test_to_numpy_extra(array): tm.assert_equal(array, original) +@pytest.mark.parametrize("as_index", [True, False]) @pytest.mark.parametrize( "values", [ @@ -932,9 +982,23 @@ def test_to_numpy_extra(array): pd.PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"), ], ) -@pytest.mark.parametrize("klass", [list, np.array, pd.array, pd.Series]) -def test_searchsorted_datetimelike_with_listlike(values, klass): +@pytest.mark.parametrize( + "klass", + [ + list, + np.array, + pd.array, + pd.Series, + pd.Index, + pd.Categorical, + pd.CategoricalIndex, + ], +) +def test_searchsorted_datetimelike_with_listlike(values, klass, as_index): # https://github.com/pandas-dev/pandas/issues/32762 + if not as_index: + values = values._data + result = values.searchsorted(klass(values)) expected = np.array([0, 1], dtype=result.dtype)