Skip to content

Standardize datetimelike casting behavior where/setitem/searchsorted/comparison #34055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d
from pandas.core.array_algos.transforms import shift
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
from pandas.core.base import (
ExtensionArray,
NoNewAttributesMixin,
PandasObject,
_shared_docs,
)
import pandas.core.common as com
from pandas.core.construction import array, extract_array, sanitize_array
from pandas.core.indexers import check_array_indexer, deprecate_ndim_indexing
Expand Down Expand Up @@ -124,17 +129,20 @@ def func(self, other):
"scalar, which is not a category."
)
else:

# allow categorical vs object dtype array comparisons for equality
# these are only positional comparisons
if opname in ["__eq__", "__ne__"]:
return getattr(np.array(self), opname)(np.array(other))
if opname not in ["__eq__", "__ne__"]:
raise TypeError(
f"Cannot compare a Categorical for op {opname} with "
f"type {type(other)}.\nIf you want to compare values, "
"use 'np.asarray(cat) <op> other'."
)

raise TypeError(
f"Cannot compare a Categorical for op {opname} with "
f"type {type(other)}.\nIf you want to compare values, "
"use 'np.asarray(cat) <op> other'."
)
if isinstance(other, ExtensionArray) and needs_i8_conversion(other):
# We would return NotImplemented here, but that messes up
# ExtensionIndex's wrapped methods
return op(other, self)
return getattr(np.array(self), opname)(np.array(other))

func.__name__ = opname

Expand Down
12 changes: 3 additions & 9 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _validate_comparison_value(self, other):

@unpack_zerodim_and_defer(opname)
def wrapper(self, other):

try:
other = _validate_comparison_value(self, other)
except InvalidComparison:
Expand Down Expand Up @@ -759,12 +758,7 @@ def _validate_shift_value(self, fill_value):
return self._unbox(fill_value)

def _validate_listlike(
self,
value,
opname: str,
cast_str: bool = False,
cast_cat: bool = False,
allow_object: bool = False,
self, value, opname: str, cast_str: bool = False, allow_object: bool = False,
):
if isinstance(value, type(self)):
return value
Expand All @@ -783,7 +777,7 @@ def _validate_listlike(
except ValueError:
pass

if cast_cat and is_categorical_dtype(value.dtype):
if is_categorical_dtype(value.dtype):
# e.g. we have a Categorical holding self.dtype
if is_dtype_equal(value.categories.dtype, self.dtype):
# TODO: do we need equal dtype or just comparable?
Expand Down Expand Up @@ -868,7 +862,7 @@ def _validate_where_value(self, other):
raise TypeError(f"Where requires matching dtype, not {type(other)}")

else:
other = self._validate_listlike(other, "where", cast_cat=True)
other = self._validate_listlike(other, "where")
self._check_compatible_with(other, setitem=True)

self._check_compatible_with(other, setitem=True)
Expand Down
68 changes: 66 additions & 2 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,41 @@ def test_compare_len1_raises(self):
with pytest.raises(ValueError, match="Lengths must match"):
idx <= idx[[0]]

@pytest.mark.parametrize("reverse", [True, False])
@pytest.mark.parametrize("as_index", [True, False])
def test_compare_categorical_dtype(self, arr1d, as_index, reverse, ordered):
other = pd.Categorical(arr1d, ordered=ordered)
if as_index:
other = pd.CategoricalIndex(other)

left, right = arr1d, other
if reverse:
left, right = right, left

ones = np.ones(arr1d.shape, dtype=bool)
zeros = ~ones

result = left == right
tm.assert_numpy_array_equal(result, ones)

result = left != right
tm.assert_numpy_array_equal(result, zeros)

if not reverse and not as_index:
# Otherwise Categorical raises TypeError bc it is not ordered
# TODO: we should probably get the same behavior regardless?
result = left < right
tm.assert_numpy_array_equal(result, zeros)

result = left <= right
tm.assert_numpy_array_equal(result, ones)

result = left > right
tm.assert_numpy_array_equal(result, zeros)

result = left >= right
tm.assert_numpy_array_equal(result, ones)

def test_take(self):
data = np.arange(100, dtype="i8") * 24 * 3600 * 10 ** 9
np.random.shuffle(data)
Expand Down Expand Up @@ -251,6 +286,20 @@ def test_setitem_str_array(self, arr1d):

tm.assert_equal(arr1d, expected)

@pytest.mark.parametrize("as_index", [True, False])
def test_setitem_categorical(self, arr1d, as_index):
expected = arr1d.copy()[::-1]
if not isinstance(expected, PeriodArray):
expected = expected._with_freq(None)

cat = pd.Categorical(arr1d)
if as_index:
cat = pd.CategoricalIndex(cat)

arr1d[:] = cat[::-1]

tm.assert_equal(arr1d, expected)

def test_setitem_raises(self):
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
arr = self.array_cls(data, freq="D")
Expand Down Expand Up @@ -924,6 +973,7 @@ def test_to_numpy_extra(array):
tm.assert_equal(array, original)


@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize(
"values",
[
Expand All @@ -932,9 +982,23 @@ def test_to_numpy_extra(array):
pd.PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"),
],
)
@pytest.mark.parametrize("klass", [list, np.array, pd.array, pd.Series])
def test_searchsorted_datetimelike_with_listlike(values, klass):
@pytest.mark.parametrize(
"klass",
[
list,
np.array,
pd.array,
pd.Series,
pd.Index,
pd.Categorical,
pd.CategoricalIndex,
],
)
def test_searchsorted_datetimelike_with_listlike(values, klass, as_index):
# https://github.com/pandas-dev/pandas/issues/32762
if not as_index:
values = values._data

result = values.searchsorted(klass(values))
expected = np.array([0, 1], dtype=result.dtype)

Expand Down