Skip to content

Commit c22d022

Browse files
jbrockmendeljreback
authored andcommitted
REF/TST: PeriodArray comparisons with listlike (#30654)
1 parent 1c35fba commit c22d022

File tree

3 files changed

+112
-10
lines changed

3 files changed

+112
-10
lines changed

pandas/core/arrays/period.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
is_datetime64_dtype,
3030
is_float_dtype,
3131
is_list_like,
32+
is_object_dtype,
3233
is_period_dtype,
3334
pandas_dtype,
3435
)
@@ -41,6 +42,7 @@
4142
)
4243
from pandas.core.dtypes.missing import isna, notna
4344

45+
from pandas.core import ops
4446
import pandas.core.algorithms as algos
4547
from pandas.core.arrays import datetimelike as dtl
4648
import pandas.core.common as com
@@ -92,22 +94,44 @@ def wrapper(self, other):
9294
self._check_compatible_with(other)
9395

9496
result = ordinal_op(other.ordinal)
95-
elif isinstance(other, cls):
96-
self._check_compatible_with(other)
97-
98-
result = ordinal_op(other.asi8)
99-
100-
mask = self._isnan | other._isnan
101-
if mask.any():
102-
result[mask] = nat_result
10397

104-
return result
10598
elif other is NaT:
10699
result = np.empty(len(self.asi8), dtype=bool)
107100
result.fill(nat_result)
108-
else:
101+
102+
elif not is_list_like(other):
109103
return invalid_comparison(self, other, op)
110104

105+
else:
106+
if isinstance(other, list):
107+
# TODO: could use pd.Index to do inference?
108+
other = np.array(other)
109+
110+
if not isinstance(other, (np.ndarray, cls)):
111+
return invalid_comparison(self, other, op)
112+
113+
if is_object_dtype(other):
114+
with np.errstate(all="ignore"):
115+
result = ops.comp_method_OBJECT_ARRAY(
116+
op, self.astype(object), other
117+
)
118+
o_mask = isna(other)
119+
120+
elif not is_period_dtype(other):
121+
# e.g. is_timedelta64_dtype(other)
122+
return invalid_comparison(self, other, op)
123+
124+
else:
125+
assert isinstance(other, cls), type(other)
126+
127+
self._check_compatible_with(other)
128+
129+
result = ordinal_op(other.asi8)
130+
o_mask = other._isnan
131+
132+
if o_mask.any():
133+
result[o_mask] = nat_result
134+
111135
if self._hasnans:
112136
result[self._isnan] = nat_result
113137

pandas/core/indexes/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def cmp_method(self, other):
107107
if is_object_dtype(self) and isinstance(other, ABCCategorical):
108108
left = type(other)(self._values, dtype=other.dtype)
109109
return op(left, other)
110+
elif is_object_dtype(self) and isinstance(other, ExtensionArray):
111+
# e.g. PeriodArray
112+
with np.errstate(all="ignore"):
113+
result = op(self.values, other)
114+
110115
elif is_object_dtype(self) and not isinstance(self, ABCMultiIndex):
111116
# don't pass MultiIndex
112117
with np.errstate(all="ignore"):

pandas/tests/arithmetic/test_period.py

+73
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,79 @@ def test_compare_invalid_scalar(self, box_with_array, scalar):
5050
parr = tm.box_expected(pi, box_with_array)
5151
assert_invalid_comparison(parr, scalar, box_with_array)
5252

53+
@pytest.mark.parametrize(
54+
"other",
55+
[
56+
pd.date_range("2000", periods=4).array,
57+
pd.timedelta_range("1D", periods=4).array,
58+
np.arange(4),
59+
np.arange(4).astype(np.float64),
60+
list(range(4)),
61+
],
62+
)
63+
def test_compare_invalid_listlike(self, box_with_array, other):
64+
pi = pd.period_range("2000", periods=4)
65+
parr = tm.box_expected(pi, box_with_array)
66+
assert_invalid_comparison(parr, other, box_with_array)
67+
68+
@pytest.mark.parametrize("other_box", [list, np.array, lambda x: x.astype(object)])
69+
def test_compare_object_dtype(self, box_with_array, other_box):
70+
pi = pd.period_range("2000", periods=5)
71+
parr = tm.box_expected(pi, box_with_array)
72+
73+
xbox = np.ndarray if box_with_array is pd.Index else box_with_array
74+
75+
other = other_box(pi)
76+
77+
expected = np.array([True, True, True, True, True])
78+
expected = tm.box_expected(expected, xbox)
79+
80+
result = parr == other
81+
tm.assert_equal(result, expected)
82+
result = parr <= other
83+
tm.assert_equal(result, expected)
84+
result = parr >= other
85+
tm.assert_equal(result, expected)
86+
87+
result = parr != other
88+
tm.assert_equal(result, ~expected)
89+
result = parr < other
90+
tm.assert_equal(result, ~expected)
91+
result = parr > other
92+
tm.assert_equal(result, ~expected)
93+
94+
other = other_box(pi[::-1])
95+
96+
expected = np.array([False, False, True, False, False])
97+
expected = tm.box_expected(expected, xbox)
98+
result = parr == other
99+
tm.assert_equal(result, expected)
100+
101+
expected = np.array([True, True, True, False, False])
102+
expected = tm.box_expected(expected, xbox)
103+
result = parr <= other
104+
tm.assert_equal(result, expected)
105+
106+
expected = np.array([False, False, True, True, True])
107+
expected = tm.box_expected(expected, xbox)
108+
result = parr >= other
109+
tm.assert_equal(result, expected)
110+
111+
expected = np.array([True, True, False, True, True])
112+
expected = tm.box_expected(expected, xbox)
113+
result = parr != other
114+
tm.assert_equal(result, expected)
115+
116+
expected = np.array([True, True, False, False, False])
117+
expected = tm.box_expected(expected, xbox)
118+
result = parr < other
119+
tm.assert_equal(result, expected)
120+
121+
expected = np.array([False, False, False, True, True])
122+
expected = tm.box_expected(expected, xbox)
123+
result = parr > other
124+
tm.assert_equal(result, expected)
125+
53126

54127
class TestPeriodIndexComparisons:
55128
# TODO: parameterize over boxes

0 commit comments

Comments
 (0)