Skip to content

Commit a3bcbf8

Browse files
authored
BUG: Period.__eq__ numpy scalar (#44182 (comment)) (#44285)
1 parent 669acb4 commit a3bcbf8

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

pandas/_libs/tslibs/period.pyx

+6-2
Original file line numberDiff line numberDiff line change
@@ -1657,8 +1657,12 @@ cdef class _Period(PeriodMixin):
16571657
elif other is NaT:
16581658
return _nat_scalar_rules[op]
16591659
elif util.is_array(other):
1660-
# in particular ndarray[object]; see test_pi_cmp_period
1661-
return np.array([PyObject_RichCompare(self, x, op) for x in other])
1660+
# GH#44285
1661+
if cnp.PyArray_IsZeroDim(other):
1662+
return PyObject_RichCompare(self, other.item(), op)
1663+
else:
1664+
# in particular ndarray[object]; see test_pi_cmp_period
1665+
return np.array([PyObject_RichCompare(self, x, op) for x in other])
16621666
return NotImplemented
16631667

16641668
def __hash__(self):

pandas/tests/arithmetic/test_period.py

+4
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def test_pi_cmp_period(self):
189189
result = idx.values.reshape(10, 2) < idx[10]
190190
tm.assert_numpy_array_equal(result, exp.reshape(10, 2))
191191

192+
# Tests Period.__richcmp__ against ndarray[object, ndim=0]
193+
result = idx < np.array(idx[10])
194+
tm.assert_numpy_array_equal(result, exp)
195+
192196
# TODO: moved from test_datetime64; de-duplicate with version below
193197
def test_parr_cmp_period_scalar2(self, box_with_array):
194198
xbox = get_expected_box(box_with_array)

pandas/tests/scalar/period/test_period.py

+10
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,16 @@ def test_period_cmp_nat(self):
11481148
assert not left <= right
11491149
assert not left >= right
11501150

1151+
@pytest.mark.parametrize(
1152+
"zerodim_arr, expected",
1153+
((np.array(0), False), (np.array(Period("2000-01", "M")), True)),
1154+
)
1155+
def test_comparison_numpy_zerodim_arr(self, zerodim_arr, expected):
1156+
p = Period("2000-01", "M")
1157+
1158+
assert (p == zerodim_arr) is expected
1159+
assert (zerodim_arr == p) is expected
1160+
11511161

11521162
class TestArithmetic:
11531163
def test_sub_delta(self):

0 commit comments

Comments
 (0)