From dc5568ffdf426875feb678fb1379c204a4a85c06 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Mon, 20 Apr 2020 20:17:27 -0700 Subject: [PATCH 1/3] REF: implement _validate_comparison_value --- pandas/core/arrays/datetimelike.py | 64 +++++++++++++++++++----------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 27b2ed822a49f..89dff5bdf2b7f 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -60,29 +60,24 @@ def _datetimelike_array_cmp(cls, op): opname = f"__{op.__name__}__" nat_result = opname == "__ne__" - @unpack_zerodim_and_defer(opname) - def wrapper(self, other): + class InvalidComparison(Exception): + pass + def _validate_comparison_value(self, other): if isinstance(other, str): try: # GH#18435 strings get a pass from tzawareness compat other = self._scalar_from_string(other) except ValueError: # failed to parse as Timestamp/Timedelta/Period - return invalid_comparison(self, other, op) + raise InvalidComparison(other) if isinstance(other, self._recognized_scalars) or other is NaT: other = self._scalar_type(other) self._check_compatible_with(other) - other_i8 = self._unbox_scalar(other) - - result = op(self.view("i8"), other_i8) - if isna(other): - result.fill(nat_result) - elif not is_list_like(other): - return invalid_comparison(self, other, op) + raise InvalidComparison(other) elif len(other) != len(self): raise ValueError("Lengths must match") @@ -93,9 +88,39 @@ def wrapper(self, other): other = np.array(other) if not isinstance(other, (np.ndarray, type(self))): - return invalid_comparison(self, other, op) + raise InvalidComparison(other) + + elif is_object_dtype(other.dtype): + pass + + elif not type(self)._is_recognized_dtype(other.dtype): + raise InvalidComparison(other) + + else: + # For PeriodDType this casting is unnecessary + other = type(self)._from_sequence(other) + self._check_compatible_with(other) + + return other + + @unpack_zerodim_and_defer(opname) + def wrapper(self, other): + + try: + other = _validate_comparison_value(self, other) + except InvalidComparison: + return invalid_comparison(self, other, op) - if is_object_dtype(other): + if isinstance(other, self._scalar_type) or other is NaT: + other_i8 = self._unbox_scalar(other) + + result = op(self.view("i8"), other_i8) + o_mask = isna(other) + + else: + # At this point we have either an ndarray[object] or our own type + + if is_object_dtype(other.dtype): # We have to use comp_method_OBJECT_ARRAY instead of numpy # comparison otherwise it would fail to raise when # comparing tz-aware and tz-naive @@ -105,22 +130,13 @@ def wrapper(self, other): ) o_mask = isna(other) - elif not type(self)._is_recognized_dtype(other.dtype): - return invalid_comparison(self, other, op) - else: - # For PeriodDType this casting is unnecessary - other = type(self)._from_sequence(other) - self._check_compatible_with(other) - + # Then type(other) == type(self) result = op(self.view("i8"), other.view("i8")) o_mask = other._isnan - if o_mask.any(): - result[o_mask] = nat_result - - if self._hasnans: - result[self._isnan] = nat_result + if self._hasnans | np.any(o_mask): + result[self._isnan | o_mask] = nat_result return result From 873c7d45e4df8915a5f5cfc91cf961649f12e0eb Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 21 Apr 2020 07:41:40 -0700 Subject: [PATCH 2/3] CLN: simplify comparison method --- pandas/core/arrays/datetimelike.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 89dff5bdf2b7f..804b1b649ea4c 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -98,6 +98,7 @@ def _validate_comparison_value(self, other): else: # For PeriodDType this casting is unnecessary + # TODO: use Index to do inference? other = type(self)._from_sequence(other) self._check_compatible_with(other) @@ -111,11 +112,13 @@ def wrapper(self, other): except InvalidComparison: return invalid_comparison(self, other, op) + o_mask = isna(other) + i8vals = self.asi8 + if isinstance(other, self._scalar_type) or other is NaT: other_i8 = self._unbox_scalar(other) - result = op(self.view("i8"), other_i8) - o_mask = isna(other) + result = op(i8vals, other_i8) else: # At this point we have either an ndarray[object] or our own type @@ -128,12 +131,11 @@ def wrapper(self, other): result = ops.comp_method_OBJECT_ARRAY( op, self.astype(object), other ) - o_mask = isna(other) else: # Then type(other) == type(self) - result = op(self.view("i8"), other.view("i8")) - o_mask = other._isnan + other_i8 = other.asi8 + result = op(i8vals, other_i8) if self._hasnans | np.any(o_mask): result[self._isnan | o_mask] = nat_result From d3d52d8d4aeda4eb9d16403054db347cf9869011 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 21 Apr 2020 12:28:44 -0700 Subject: [PATCH 3/3] reorder checks --- pandas/core/arrays/datetimelike.py | 31 ++++++++++++------------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 804b1b649ea4c..63932e64d3b05 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -112,31 +112,24 @@ def wrapper(self, other): except InvalidComparison: return invalid_comparison(self, other, op) - o_mask = isna(other) - i8vals = self.asi8 + dtype = getattr(other, "dtype", None) + if is_object_dtype(dtype): + # We have to use comp_method_OBJECT_ARRAY instead of numpy + # comparison otherwise it would fail to raise when + # comparing tz-aware and tz-naive + with np.errstate(all="ignore"): + result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other) + return result if isinstance(other, self._scalar_type) or other is NaT: other_i8 = self._unbox_scalar(other) - - result = op(i8vals, other_i8) - else: - # At this point we have either an ndarray[object] or our own type - - if is_object_dtype(other.dtype): - # We have to use comp_method_OBJECT_ARRAY instead of numpy - # comparison otherwise it would fail to raise when - # comparing tz-aware and tz-naive - with np.errstate(all="ignore"): - result = ops.comp_method_OBJECT_ARRAY( - op, self.astype(object), other - ) + # Then type(other) == type(self) + other_i8 = other.asi8 - else: - # Then type(other) == type(self) - other_i8 = other.asi8 - result = op(i8vals, other_i8) + result = op(self.asi8, other_i8) + o_mask = isna(other) if self._hasnans | np.any(o_mask): result[self._isnan | o_mask] = nat_result