Skip to content

REF: implement _validate_comparison_value #33716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 39 additions & 28 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,24 @@ def _datetimelike_array_cmp(cls, op):
opname = f"__{op.__name__}__"
nat_result = opname == "__ne__"

@unpack_zerodim_and_defer(opname)
def wrapper(self, other):
class InvalidComparison(Exception):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this meant to be internal or external?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk i would document that as such (followon ok)

pass

def _validate_comparison_value(self, other):
if isinstance(other, str):
try:
# GH#18435 strings get a pass from tzawareness compat
other = self._scalar_from_string(other)
except ValueError:
# failed to parse as Timestamp/Timedelta/Period
return invalid_comparison(self, other, op)
raise InvalidComparison(other)

if isinstance(other, self._recognized_scalars) or other is NaT:
other = self._scalar_type(other)
self._check_compatible_with(other)

other_i8 = self._unbox_scalar(other)

result = op(self.view("i8"), other_i8)
if isna(other):
result.fill(nat_result)

elif not is_list_like(other):
return invalid_comparison(self, other, op)
raise InvalidComparison(other)

elif len(other) != len(self):
raise ValueError("Lengths must match")
Expand All @@ -93,34 +88,50 @@ def wrapper(self, other):
other = np.array(other)

if not isinstance(other, (np.ndarray, type(self))):
return invalid_comparison(self, other, op)

if is_object_dtype(other):
# We have to use comp_method_OBJECT_ARRAY instead of numpy
# comparison otherwise it would fail to raise when
# comparing tz-aware and tz-naive
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(
op, self.astype(object), other
)
o_mask = isna(other)
raise InvalidComparison(other)

elif is_object_dtype(other.dtype):
pass

elif not type(self)._is_recognized_dtype(other.dtype):
return invalid_comparison(self, other, op)
raise InvalidComparison(other)

else:
# For PeriodDType this casting is unnecessary
# TODO: use Index to do inference?
other = type(self)._from_sequence(other)
self._check_compatible_with(other)

result = op(self.view("i8"), other.view("i8"))
o_mask = other._isnan
return other

if o_mask.any():
result[o_mask] = nat_result
@unpack_zerodim_and_defer(opname)
def wrapper(self, other):

if self._hasnans:
result[self._isnan] = nat_result
try:
other = _validate_comparison_value(self, other)
except InvalidComparison:
return invalid_comparison(self, other, op)

dtype = getattr(other, "dtype", None)
if is_object_dtype(dtype):
# We have to use comp_method_OBJECT_ARRAY instead of numpy
# comparison otherwise it would fail to raise when
# comparing tz-aware and tz-naive
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
return result

if isinstance(other, self._scalar_type) or other is NaT:
other_i8 = self._unbox_scalar(other)
else:
# Then type(other) == type(self)
other_i8 = other.asi8

result = op(self.asi8, other_i8)

o_mask = isna(other)
if self._hasnans | np.any(o_mask):
result[self._isnan | o_mask] = nat_result

return result

Expand Down