|
13 | 13 |
|
14 | 14 | from pandas.util._decorators import cache_readonly
|
15 | 15 | from pandas.errors import PerformanceWarning
|
| 16 | +from pandas import compat |
16 | 17 |
|
17 | 18 | from pandas.core.dtypes.common import (
|
18 | 19 | _NS_DTYPE,
|
| 20 | + is_datetimelike, |
19 | 21 | is_datetime64tz_dtype,
|
20 | 22 | is_datetime64_dtype,
|
21 | 23 | is_timedelta64_dtype,
|
22 | 24 | _ensure_int64)
|
23 | 25 | from pandas.core.dtypes.dtypes import DatetimeTZDtype
|
| 26 | +from pandas.core.dtypes.missing import isna |
| 27 | +from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries |
24 | 28 |
|
| 29 | +import pandas.core.common as com |
25 | 30 | from pandas.core.algorithms import checked_add_with_arr
|
26 | 31 |
|
27 | 32 | from pandas.tseries.frequencies import to_offset, DateOffset
|
28 | 33 | from pandas.tseries.offsets import Tick
|
29 | 34 |
|
30 |
| -from .datetimelike import DatetimeLikeArrayMixin |
| 35 | +from pandas.core.arrays import datetimelike as dtl |
| 36 | + |
| 37 | + |
| 38 | +def _to_m8(key, tz=None): |
| 39 | + """ |
| 40 | + Timestamp-like => dt64 |
| 41 | + """ |
| 42 | + if not isinstance(key, Timestamp): |
| 43 | + # this also converts strings |
| 44 | + key = Timestamp(key, tz=tz) |
| 45 | + |
| 46 | + return np.int64(conversion.pydt_to_i8(key)).view(_NS_DTYPE) |
31 | 47 |
|
32 | 48 |
|
33 | 49 | def _field_accessor(name, field, docstring=None):
|
@@ -68,7 +84,58 @@ def f(self):
|
68 | 84 | return property(f)
|
69 | 85 |
|
70 | 86 |
|
71 |
| -class DatetimeArrayMixin(DatetimeLikeArrayMixin): |
| 87 | +def _dt_array_cmp(opname, cls): |
| 88 | + """ |
| 89 | + Wrap comparison operations to convert datetime-like to datetime64 |
| 90 | + """ |
| 91 | + nat_result = True if opname == '__ne__' else False |
| 92 | + |
| 93 | + def wrapper(self, other): |
| 94 | + meth = getattr(dtl.DatetimeLikeArrayMixin, opname) |
| 95 | + |
| 96 | + if isinstance(other, (datetime, np.datetime64, compat.string_types)): |
| 97 | + if isinstance(other, datetime): |
| 98 | + # GH#18435 strings get a pass from tzawareness compat |
| 99 | + self._assert_tzawareness_compat(other) |
| 100 | + |
| 101 | + other = _to_m8(other, tz=self.tz) |
| 102 | + result = meth(self, other) |
| 103 | + if isna(other): |
| 104 | + result.fill(nat_result) |
| 105 | + else: |
| 106 | + if isinstance(other, list): |
| 107 | + other = type(self)(other) |
| 108 | + elif not isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries)): |
| 109 | + # Following Timestamp convention, __eq__ is all-False |
| 110 | + # and __ne__ is all True, others raise TypeError. |
| 111 | + if opname == '__eq__': |
| 112 | + return np.zeros(shape=self.shape, dtype=bool) |
| 113 | + elif opname == '__ne__': |
| 114 | + return np.ones(shape=self.shape, dtype=bool) |
| 115 | + raise TypeError('%s type object %s' % |
| 116 | + (type(other), str(other))) |
| 117 | + |
| 118 | + if is_datetimelike(other): |
| 119 | + self._assert_tzawareness_compat(other) |
| 120 | + |
| 121 | + result = meth(self, np.asarray(other)) |
| 122 | + result = com._values_from_object(result) |
| 123 | + |
| 124 | + # Make sure to pass an array to result[...]; indexing with |
| 125 | + # Series breaks with older version of numpy |
| 126 | + o_mask = np.array(isna(other)) |
| 127 | + if o_mask.any(): |
| 128 | + result[o_mask] = nat_result |
| 129 | + |
| 130 | + if self.hasnans: |
| 131 | + result[self._isnan] = nat_result |
| 132 | + |
| 133 | + return result |
| 134 | + |
| 135 | + return compat.set_function_name(wrapper, opname, cls) |
| 136 | + |
| 137 | + |
| 138 | +class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin): |
72 | 139 | """
|
73 | 140 | Assumes that subclass __new__/__init__ defines:
|
74 | 141 | tz
|
@@ -222,6 +289,18 @@ def __iter__(self):
|
222 | 289 | # -----------------------------------------------------------------
|
223 | 290 | # Comparison Methods
|
224 | 291 |
|
| 292 | + @classmethod |
| 293 | + def _add_comparison_methods(cls): |
| 294 | + """add in comparison methods""" |
| 295 | + cls.__eq__ = _dt_array_cmp('__eq__', cls) |
| 296 | + cls.__ne__ = _dt_array_cmp('__ne__', cls) |
| 297 | + cls.__lt__ = _dt_array_cmp('__lt__', cls) |
| 298 | + cls.__gt__ = _dt_array_cmp('__gt__', cls) |
| 299 | + cls.__le__ = _dt_array_cmp('__le__', cls) |
| 300 | + cls.__ge__ = _dt_array_cmp('__ge__', cls) |
| 301 | + # TODO: Some classes pass __eq__ while others pass operator.eq; |
| 302 | + # standardize this. |
| 303 | + |
225 | 304 | def _has_same_tz(self, other):
|
226 | 305 | zzone = self._timezone
|
227 | 306 |
|
@@ -335,7 +414,7 @@ def _add_delta(self, delta):
|
335 | 414 | The result's name is set outside of _add_delta by the calling
|
336 | 415 | method (__add__ or __sub__)
|
337 | 416 | """
|
338 |
| - from pandas.core.arrays.timedelta import TimedeltaArrayMixin |
| 417 | + from pandas.core.arrays.timedeltas import TimedeltaArrayMixin |
339 | 418 |
|
340 | 419 | if isinstance(delta, (Tick, timedelta, np.timedelta64)):
|
341 | 420 | new_values = self._add_delta_td(delta)
|
@@ -1021,3 +1100,6 @@ def to_julian_date(self):
|
1021 | 1100 | self.microsecond / 3600.0 / 1e+6 +
|
1022 | 1101 | self.nanosecond / 3600.0 / 1e+9
|
1023 | 1102 | ) / 24.0)
|
| 1103 | + |
| 1104 | + |
| 1105 | +DatetimeArrayMixin._add_comparison_methods() |
0 commit comments