Skip to content

Commit 7b62393

Browse files
jbrockmendelrhshadrach
authored andcommitted
REF: implement _validate_comparison_value (pandas-dev#33716)
1 parent 705529e commit 7b62393

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

pandas/core/arrays/datetimelike.py

+39-28
Original file line numberDiff line numberDiff line change
@@ -60,29 +60,24 @@ def _datetimelike_array_cmp(cls, op):
6060
opname = f"__{op.__name__}__"
6161
nat_result = opname == "__ne__"
6262

63-
@unpack_zerodim_and_defer(opname)
64-
def wrapper(self, other):
63+
class InvalidComparison(Exception):
64+
pass
6565

66+
def _validate_comparison_value(self, other):
6667
if isinstance(other, str):
6768
try:
6869
# GH#18435 strings get a pass from tzawareness compat
6970
other = self._scalar_from_string(other)
7071
except ValueError:
7172
# failed to parse as Timestamp/Timedelta/Period
72-
return invalid_comparison(self, other, op)
73+
raise InvalidComparison(other)
7374

7475
if isinstance(other, self._recognized_scalars) or other is NaT:
7576
other = self._scalar_type(other)
7677
self._check_compatible_with(other)
7778

78-
other_i8 = self._unbox_scalar(other)
79-
80-
result = op(self.view("i8"), other_i8)
81-
if isna(other):
82-
result.fill(nat_result)
83-
8479
elif not is_list_like(other):
85-
return invalid_comparison(self, other, op)
80+
raise InvalidComparison(other)
8681

8782
elif len(other) != len(self):
8883
raise ValueError("Lengths must match")
@@ -93,34 +88,50 @@ def wrapper(self, other):
9388
other = np.array(other)
9489

9590
if not isinstance(other, (np.ndarray, type(self))):
96-
return invalid_comparison(self, other, op)
97-
98-
if is_object_dtype(other):
99-
# We have to use comp_method_OBJECT_ARRAY instead of numpy
100-
# comparison otherwise it would fail to raise when
101-
# comparing tz-aware and tz-naive
102-
with np.errstate(all="ignore"):
103-
result = ops.comp_method_OBJECT_ARRAY(
104-
op, self.astype(object), other
105-
)
106-
o_mask = isna(other)
91+
raise InvalidComparison(other)
92+
93+
elif is_object_dtype(other.dtype):
94+
pass
10795

10896
elif not type(self)._is_recognized_dtype(other.dtype):
109-
return invalid_comparison(self, other, op)
97+
raise InvalidComparison(other)
11098

11199
else:
112100
# For PeriodDType this casting is unnecessary
101+
# TODO: use Index to do inference?
113102
other = type(self)._from_sequence(other)
114103
self._check_compatible_with(other)
115104

116-
result = op(self.view("i8"), other.view("i8"))
117-
o_mask = other._isnan
105+
return other
118106

119-
if o_mask.any():
120-
result[o_mask] = nat_result
107+
@unpack_zerodim_and_defer(opname)
108+
def wrapper(self, other):
121109

122-
if self._hasnans:
123-
result[self._isnan] = nat_result
110+
try:
111+
other = _validate_comparison_value(self, other)
112+
except InvalidComparison:
113+
return invalid_comparison(self, other, op)
114+
115+
dtype = getattr(other, "dtype", None)
116+
if is_object_dtype(dtype):
117+
# We have to use comp_method_OBJECT_ARRAY instead of numpy
118+
# comparison otherwise it would fail to raise when
119+
# comparing tz-aware and tz-naive
120+
with np.errstate(all="ignore"):
121+
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
122+
return result
123+
124+
if isinstance(other, self._scalar_type) or other is NaT:
125+
other_i8 = self._unbox_scalar(other)
126+
else:
127+
# Then type(other) == type(self)
128+
other_i8 = other.asi8
129+
130+
result = op(self.asi8, other_i8)
131+
132+
o_mask = isna(other)
133+
if self._hasnans | np.any(o_mask):
134+
result[self._isnan | o_mask] = nat_result
124135

125136
return result
126137

0 commit comments

Comments
 (0)