Skip to content

Commit 3253cb0

Browse files
jbrockmendeljreback
authored andcommitted
BUG: listlike comparisons for DTA and TDA (#30705)
1 parent 97bba51 commit 3253cb0

File tree

5 files changed

+117
-22
lines changed

5 files changed

+117
-22
lines changed

pandas/core/arrays/datetimes.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,9 @@ def wrapper(self, other):
161161
raise ValueError("Lengths must match")
162162
else:
163163
if isinstance(other, list):
164-
try:
165-
other = type(self)._from_sequence(other)
166-
except ValueError:
167-
other = np.array(other, dtype=np.object_)
168-
elif not isinstance(other, (np.ndarray, DatetimeArray)):
164+
other = np.array(other)
165+
166+
if not isinstance(other, (np.ndarray, cls)):
169167
# Following Timestamp convention, __eq__ is all-False
170168
# and __ne__ is all True, others raise TypeError.
171169
return invalid_comparison(self, other, op)
@@ -179,20 +177,14 @@ def wrapper(self, other):
179177
op, self.astype(object), other
180178
)
181179
o_mask = isna(other)
180+
182181
elif not (is_datetime64_dtype(other) or is_datetime64tz_dtype(other)):
183182
# e.g. is_timedelta64_dtype(other)
184183
return invalid_comparison(self, other, op)
184+
185185
else:
186186
self._assert_tzawareness_compat(other)
187-
188-
if (
189-
is_datetime64_dtype(other)
190-
and not is_datetime64_ns_dtype(other)
191-
or not hasattr(other, "asi8")
192-
):
193-
# e.g. other.dtype == 'datetime64[s]'
194-
# or an object-dtype ndarray
195-
other = type(self)._from_sequence(other)
187+
other = type(self)._from_sequence(other)
196188

197189
result = op(self.view("i8"), other.view("i8"))
198190
o_mask = other._isnan

pandas/core/arrays/timedeltas.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from pandas.core.dtypes.missing import isna
4040

41-
from pandas.core import nanops
41+
from pandas.core import nanops, ops
4242
from pandas.core.algorithms import checked_add_with_arr
4343
import pandas.core.common as com
4444
from pandas.core.ops.common import unpack_zerodim_and_defer
@@ -103,15 +103,29 @@ def wrapper(self, other):
103103
raise ValueError("Lengths must match")
104104

105105
else:
106-
try:
107-
other = type(self)._from_sequence(other)._data
108-
except (ValueError, TypeError):
106+
if isinstance(other, list):
107+
other = np.array(other)
108+
109+
if not isinstance(other, (np.ndarray, cls)):
110+
return invalid_comparison(self, other, op)
111+
112+
if is_object_dtype(other):
113+
with np.errstate(all="ignore"):
114+
result = ops.comp_method_OBJECT_ARRAY(
115+
op, self.astype(object), other
116+
)
117+
o_mask = isna(other)
118+
119+
elif not is_timedelta64_dtype(other):
120+
# e.g. other is datetimearray
109121
return invalid_comparison(self, other, op)
110122

111-
result = op(self.view("i8"), other.view("i8"))
112-
result = com.values_from_object(result)
123+
else:
124+
other = type(self)._from_sequence(other)
125+
126+
result = op(self.view("i8"), other.view("i8"))
127+
o_mask = other._isnan
113128

114-
o_mask = np.array(isna(other))
115129
if o_mask.any():
116130
result[o_mask] = nat_result
117131

pandas/tests/arithmetic/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def assert_invalid_comparison(left, right, box):
7070
result = right != left
7171
tm.assert_equal(result, ~expected)
7272

73-
msg = "Invalid comparison between"
73+
msg = "Invalid comparison between|Cannot compare type|not supported between"
7474
with pytest.raises(TypeError, match=msg):
7575
left < right
7676
with pytest.raises(TypeError, match=msg):

pandas/tests/arithmetic/test_datetime64.py

+46
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,52 @@ def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_arra
8585
dtarr = tm.box_expected(rng, box_with_array)
8686
assert_invalid_comparison(dtarr, other, box_with_array)
8787

88+
@pytest.mark.parametrize(
89+
"other",
90+
[
91+
list(range(10)),
92+
np.arange(10),
93+
np.arange(10).astype(np.float32),
94+
np.arange(10).astype(object),
95+
pd.timedelta_range("1ns", periods=10).array,
96+
np.array(pd.timedelta_range("1ns", periods=10)),
97+
list(pd.timedelta_range("1ns", periods=10)),
98+
pd.timedelta_range("1 Day", periods=10).astype(object),
99+
pd.period_range("1971-01-01", freq="D", periods=10).array,
100+
pd.period_range("1971-01-01", freq="D", periods=10).astype(object),
101+
],
102+
)
103+
def test_dt64arr_cmp_arraylike_invalid(self, other, tz_naive_fixture):
104+
# We don't parametrize this over box_with_array because listlike
105+
# other plays poorly with assert_invalid_comparison reversed checks
106+
tz = tz_naive_fixture
107+
108+
dta = date_range("1970-01-01", freq="ns", periods=10, tz=tz)._data
109+
assert_invalid_comparison(dta, other, tm.to_array)
110+
111+
def test_dt64arr_cmp_mixed_invalid(self, tz_naive_fixture):
112+
tz = tz_naive_fixture
113+
114+
dta = date_range("1970-01-01", freq="h", periods=5, tz=tz)._data
115+
116+
other = np.array([0, 1, 2, dta[3], pd.Timedelta(days=1)])
117+
result = dta == other
118+
expected = np.array([False, False, False, True, False])
119+
tm.assert_numpy_array_equal(result, expected)
120+
121+
result = dta != other
122+
tm.assert_numpy_array_equal(result, ~expected)
123+
124+
msg = "Invalid comparison between|Cannot compare type|not supported between"
125+
with pytest.raises(TypeError, match=msg):
126+
dta < other
127+
with pytest.raises(TypeError, match=msg):
128+
dta > other
129+
with pytest.raises(TypeError, match=msg):
130+
dta <= other
131+
with pytest.raises(TypeError, match=msg):
132+
dta >= other
133+
88134
def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
89135
# GH#22242, GH#22163 DataFrame considered NaT == ts incorrectly
90136
tz = tz_naive_fixture

pandas/tests/arithmetic/test_timedelta64.py

+43
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,49 @@ def test_td64_comparisons_invalid(self, box_with_array, invalid):
7676

7777
assert_invalid_comparison(obj, invalid, box)
7878

79+
@pytest.mark.parametrize(
80+
"other",
81+
[
82+
list(range(10)),
83+
np.arange(10),
84+
np.arange(10).astype(np.float32),
85+
np.arange(10).astype(object),
86+
pd.date_range("1970-01-01", periods=10, tz="UTC").array,
87+
np.array(pd.date_range("1970-01-01", periods=10)),
88+
list(pd.date_range("1970-01-01", periods=10)),
89+
pd.date_range("1970-01-01", periods=10).astype(object),
90+
pd.period_range("1971-01-01", freq="D", periods=10).array,
91+
pd.period_range("1971-01-01", freq="D", periods=10).astype(object),
92+
],
93+
)
94+
def test_td64arr_cmp_arraylike_invalid(self, other):
95+
# We don't parametrize this over box_with_array because listlike
96+
# other plays poorly with assert_invalid_comparison reversed checks
97+
98+
rng = timedelta_range("1 days", periods=10)._data
99+
assert_invalid_comparison(rng, other, tm.to_array)
100+
101+
def test_td64arr_cmp_mixed_invalid(self):
102+
rng = timedelta_range("1 days", periods=5)._data
103+
104+
other = np.array([0, 1, 2, rng[3], pd.Timestamp.now()])
105+
result = rng == other
106+
expected = np.array([False, False, False, True, False])
107+
tm.assert_numpy_array_equal(result, expected)
108+
109+
result = rng != other
110+
tm.assert_numpy_array_equal(result, ~expected)
111+
112+
msg = "Invalid comparison between|Cannot compare type|not supported between"
113+
with pytest.raises(TypeError, match=msg):
114+
rng < other
115+
with pytest.raises(TypeError, match=msg):
116+
rng > other
117+
with pytest.raises(TypeError, match=msg):
118+
rng <= other
119+
with pytest.raises(TypeError, match=msg):
120+
rng >= other
121+
79122

80123
class TestTimedelta64ArrayComparisons:
81124
# TODO: All of these need to be parametrized over box

0 commit comments

Comments
 (0)