|
11 | 11 | from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
|
12 | 12 | from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
|
13 | 13 | from pandas._typing import DatetimeLikeScalar
|
| 14 | +from pandas.compat import set_function_name |
14 | 15 | from pandas.compat.numpy import function as nv
|
15 | 16 | from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
|
16 | 17 | from pandas.util._decorators import Appender, Substitution
|
|
37 | 38 | from pandas.core.dtypes.inference import is_array_like
|
38 | 39 | from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
|
39 | 40 |
|
40 |
| -from pandas.core import missing, nanops |
| 41 | +from pandas.core import missing, nanops, ops |
41 | 42 | from pandas.core.algorithms import checked_add_with_arr, take, unique1d, value_counts
|
42 | 43 | import pandas.core.common as com
|
43 | 44 | from pandas.core.indexers import check_bool_array_indexer
|
44 | 45 | from pandas.core.ops.common import unpack_zerodim_and_defer
|
45 |
| -from pandas.core.ops.invalid import make_invalid_op |
| 46 | +from pandas.core.ops.invalid import invalid_comparison, make_invalid_op |
46 | 47 |
|
47 | 48 | from pandas.tseries import frequencies
|
48 | 49 | from pandas.tseries.offsets import DateOffset, Tick
|
49 | 50 |
|
50 | 51 | from .base import ExtensionArray, ExtensionOpsMixin
|
51 | 52 |
|
52 | 53 |
|
| 54 | +def _datetimelike_array_cmp(cls, op): |
| 55 | + """ |
| 56 | + Wrap comparison operations to convert Timestamp/Timedelta/Period-like to |
| 57 | + boxed scalars/arrays. |
| 58 | + """ |
| 59 | + opname = f"__{op.__name__}__" |
| 60 | + nat_result = opname == "__ne__" |
| 61 | + |
| 62 | + @unpack_zerodim_and_defer(opname) |
| 63 | + def wrapper(self, other): |
| 64 | + |
| 65 | + if isinstance(other, str): |
| 66 | + try: |
| 67 | + # GH#18435 strings get a pass from tzawareness compat |
| 68 | + other = self._scalar_from_string(other) |
| 69 | + except ValueError: |
| 70 | + # failed to parse as Timestamp/Timedelta/Period |
| 71 | + return invalid_comparison(self, other, op) |
| 72 | + |
| 73 | + if isinstance(other, self._recognized_scalars) or other is NaT: |
| 74 | + other = self._scalar_type(other) |
| 75 | + self._check_compatible_with(other) |
| 76 | + |
| 77 | + other_i8 = self._unbox_scalar(other) |
| 78 | + |
| 79 | + result = op(self.view("i8"), other_i8) |
| 80 | + if isna(other): |
| 81 | + result.fill(nat_result) |
| 82 | + |
| 83 | + elif not is_list_like(other): |
| 84 | + return invalid_comparison(self, other, op) |
| 85 | + |
| 86 | + elif len(other) != len(self): |
| 87 | + raise ValueError("Lengths must match") |
| 88 | + |
| 89 | + else: |
| 90 | + if isinstance(other, list): |
| 91 | + # TODO: could use pd.Index to do inference? |
| 92 | + other = np.array(other) |
| 93 | + |
| 94 | + if not isinstance(other, (np.ndarray, type(self))): |
| 95 | + return invalid_comparison(self, other, op) |
| 96 | + |
| 97 | + if is_object_dtype(other): |
| 98 | + # We have to use comp_method_OBJECT_ARRAY instead of numpy |
| 99 | + # comparison otherwise it would fail to raise when |
| 100 | + # comparing tz-aware and tz-naive |
| 101 | + with np.errstate(all="ignore"): |
| 102 | + result = ops.comp_method_OBJECT_ARRAY( |
| 103 | + op, self.astype(object), other |
| 104 | + ) |
| 105 | + o_mask = isna(other) |
| 106 | + |
| 107 | + elif not type(self)._is_recognized_dtype(other.dtype): |
| 108 | + return invalid_comparison(self, other, op) |
| 109 | + |
| 110 | + else: |
| 111 | + # For PeriodDType this casting is unnecessary |
| 112 | + other = type(self)._from_sequence(other) |
| 113 | + self._check_compatible_with(other) |
| 114 | + |
| 115 | + result = op(self.view("i8"), other.view("i8")) |
| 116 | + o_mask = other._isnan |
| 117 | + |
| 118 | + if o_mask.any(): |
| 119 | + result[o_mask] = nat_result |
| 120 | + |
| 121 | + if self._hasnans: |
| 122 | + result[self._isnan] = nat_result |
| 123 | + |
| 124 | + return result |
| 125 | + |
| 126 | + return set_function_name(wrapper, opname, cls) |
| 127 | + |
| 128 | + |
53 | 129 | class AttributesMixin:
|
54 | 130 | _data: np.ndarray
|
55 | 131 |
|
@@ -934,6 +1010,7 @@ def _is_unique(self):
|
934 | 1010 |
|
935 | 1011 | # ------------------------------------------------------------------
|
936 | 1012 | # Arithmetic Methods
|
| 1013 | + _create_comparison_method = classmethod(_datetimelike_array_cmp) |
937 | 1014 |
|
938 | 1015 | # pow is invalid for all three subclasses; TimedeltaArray will override
|
939 | 1016 | # the multiplication and division ops
|
@@ -1485,6 +1562,8 @@ def mean(self, skipna=True):
|
1485 | 1562 | return self._box_func(result)
|
1486 | 1563 |
|
1487 | 1564 |
|
| 1565 | +DatetimeLikeArrayMixin._add_comparison_ops() |
| 1566 | + |
1488 | 1567 | # -------------------------------------------------------------------
|
1489 | 1568 | # Shared Constructor Helpers
|
1490 | 1569 |
|
|
0 commit comments