Skip to content

Commit 6b3df29

Browse files
jbrockmendel authored and jreback committed
REF: share comparison methods for DTA/TDA/PA (#30751)
1 parent 2107da1 commit 6b3df29

File tree

4 files changed

+83
-249
lines changed

4 files changed

+83
-249
lines changed

pandas/core/arrays/datetimelike.py

Lines changed: 81 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
1212
from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
1313
from pandas._typing import DatetimeLikeScalar
14+
from pandas.compat import set_function_name
1415
from pandas.compat.numpy import function as nv
1516
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
1617
from pandas.util._decorators import Appender, Substitution
@@ -37,19 +38,94 @@
3738
from pandas.core.dtypes.inference import is_array_like
3839
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
3940

40-
from pandas.core import missing, nanops
41+
from pandas.core import missing, nanops, ops
4142
from pandas.core.algorithms import checked_add_with_arr, take, unique1d, value_counts
4243
import pandas.core.common as com
4344
from pandas.core.indexers import check_bool_array_indexer
4445
from pandas.core.ops.common import unpack_zerodim_and_defer
45-
from pandas.core.ops.invalid import make_invalid_op
46+
from pandas.core.ops.invalid import invalid_comparison, make_invalid_op
4647

4748
from pandas.tseries import frequencies
4849
from pandas.tseries.offsets import DateOffset, Tick
4950

5051
from .base import ExtensionArray, ExtensionOpsMixin
5152

5253

54+
def _datetimelike_array_cmp(cls, op):
    """
    Build the comparison dunder method for ``op`` on a datetime-like array.

    Timestamp/Timedelta/Period-like operands are converted to boxed
    scalars/arrays before the underlying i8 values are compared.

    Parameters
    ----------
    cls : type
        Class the generated method is attached to; passed through to
        ``set_function_name``.
    op : callable
        Comparison operator; ``op.__name__`` determines the dunder name.

    Returns
    -------
    callable
        The wrapped comparison method, named ``__<op.__name__>__``.
    """
    opname = f"__{op.__name__}__"
    # Slots involving NaT evaluate to True only for ``!=``.
    nat_result = opname == "__ne__"

    def _blank_own_nats(self, result):
        # Overwrite the positions where ``self`` itself holds NaT.
        if self._hasnans:
            result[self._isnan] = nat_result
        return result

    @unpack_zerodim_and_defer(opname)
    def wrapper(self, other):
        if isinstance(other, str):
            try:
                # GH#18435 strings get a pass from tzawareness compat
                other = self._scalar_from_string(other)
            except ValueError:
                # failed to parse as Timestamp/Timedelta/Period
                return invalid_comparison(self, other, op)

        # --- scalar operand -------------------------------------------
        if isinstance(other, self._recognized_scalars) or other is NaT:
            scalar = self._scalar_type(other)
            self._check_compatible_with(scalar)

            scalar_i8 = self._unbox_scalar(scalar)

            result = op(self.view("i8"), scalar_i8)
            if isna(scalar):
                result.fill(nat_result)
            return _blank_own_nats(self, result)

        # --- anything that is neither scalar nor list-like ------------
        if not is_list_like(other):
            return invalid_comparison(self, other, op)

        # --- list-like operand ----------------------------------------
        if len(other) != len(self):
            raise ValueError("Lengths must match")

        if isinstance(other, list):
            # TODO: could use pd.Index to do inference?
            other = np.array(other)

        if not isinstance(other, (np.ndarray, type(self))):
            return invalid_comparison(self, other, op)

        if is_object_dtype(other):
            # We have to use comp_method_OBJECT_ARRAY instead of numpy
            #  comparison otherwise it would fail to raise when
            #  comparing tz-aware and tz-naive
            with np.errstate(all="ignore"):
                result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
            other_nat_mask = isna(other)

        elif type(self)._is_recognized_dtype(other.dtype):
            # For PeriodDType this casting is unnecessary
            other = type(self)._from_sequence(other)
            self._check_compatible_with(other)

            result = op(self.view("i8"), other.view("i8"))
            other_nat_mask = other._isnan

        else:
            return invalid_comparison(self, other, op)

        if other_nat_mask.any():
            result[other_nat_mask] = nat_result

        return _blank_own_nats(self, result)

    return set_function_name(wrapper, opname, cls)
127+
128+
53129
class AttributesMixin:
54130
_data: np.ndarray
55131

@@ -934,6 +1010,7 @@ def _is_unique(self):
9341010

9351011
# ------------------------------------------------------------------
9361012
# Arithmetic Methods
1013+
_create_comparison_method = classmethod(_datetimelike_array_cmp)
9371014

9381015
# pow is invalid for all three subclasses; TimedeltaArray will override
9391016
# the multiplication and division ops
@@ -1485,6 +1562,8 @@ def mean(self, skipna=True):
14851562
return self._box_func(result)
14861563

14871564

1565+
DatetimeLikeArrayMixin._add_comparison_ops()
1566+
14881567
# -------------------------------------------------------------------
14891568
# Shared Constructor Helpers
14901569

pandas/core/arrays/datetimes.py

Lines changed: 1 addition & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
timezones,
1919
tzconversion,
2020
)
21-
import pandas.compat as compat
2221
from pandas.errors import PerformanceWarning
2322

2423
from pandas.core.dtypes.common import (
@@ -32,7 +31,6 @@
3231
is_dtype_equal,
3332
is_extension_array_dtype,
3433
is_float_dtype,
35-
is_list_like,
3634
is_object_dtype,
3735
is_period_dtype,
3836
is_string_dtype,
@@ -43,13 +41,10 @@
4341
from pandas.core.dtypes.generic import ABCIndexClass, ABCPandasArray, ABCSeries
4442
from pandas.core.dtypes.missing import isna
4543

46-
from pandas.core import ops
4744
from pandas.core.algorithms import checked_add_with_arr
4845
from pandas.core.arrays import datetimelike as dtl
4946
from pandas.core.arrays._ranges import generate_regular_range
5047
import pandas.core.common as com
51-
from pandas.core.ops.common import unpack_zerodim_and_defer
52-
from pandas.core.ops.invalid import invalid_comparison
5348

5449
from pandas.tseries.frequencies import get_period_alias, to_offset
5550
from pandas.tseries.offsets import Day, Tick
@@ -131,81 +126,6 @@ def f(self):
131126
return property(f)
132127

133128

134-
def _dt_array_cmp(cls, op):
135-
"""
136-
Wrap comparison operations to convert datetime-like to datetime64
137-
"""
138-
opname = f"__{op.__name__}__"
139-
nat_result = opname == "__ne__"
140-
141-
@unpack_zerodim_and_defer(opname)
142-
def wrapper(self, other):
143-
144-
if isinstance(other, str):
145-
try:
146-
# GH#18435 strings get a pass from tzawareness compat
147-
other = self._scalar_from_string(other)
148-
except ValueError:
149-
# string that cannot be parsed to Timestamp
150-
return invalid_comparison(self, other, op)
151-
152-
if isinstance(other, self._recognized_scalars) or other is NaT:
153-
other = self._scalar_type(other)
154-
self._assert_tzawareness_compat(other)
155-
156-
other_i8 = other.value
157-
158-
result = op(self.view("i8"), other_i8)
159-
if isna(other):
160-
result.fill(nat_result)
161-
162-
elif not is_list_like(other):
163-
return invalid_comparison(self, other, op)
164-
165-
elif len(other) != len(self):
166-
raise ValueError("Lengths must match")
167-
168-
else:
169-
if isinstance(other, list):
170-
other = np.array(other)
171-
172-
if not isinstance(other, (np.ndarray, cls)):
173-
# Following Timestamp convention, __eq__ is all-False
174-
# and __ne__ is all True, others raise TypeError.
175-
return invalid_comparison(self, other, op)
176-
177-
if is_object_dtype(other):
178-
# We have to use comp_method_OBJECT_ARRAY instead of numpy
179-
# comparison otherwise it would fail to raise when
180-
# comparing tz-aware and tz-naive
181-
with np.errstate(all="ignore"):
182-
result = ops.comp_method_OBJECT_ARRAY(
183-
op, self.astype(object), other
184-
)
185-
o_mask = isna(other)
186-
187-
elif not cls._is_recognized_dtype(other.dtype):
188-
# e.g. is_timedelta64_dtype(other)
189-
return invalid_comparison(self, other, op)
190-
191-
else:
192-
self._assert_tzawareness_compat(other)
193-
other = type(self)._from_sequence(other)
194-
195-
result = op(self.view("i8"), other.view("i8"))
196-
o_mask = other._isnan
197-
198-
if o_mask.any():
199-
result[o_mask] = nat_result
200-
201-
if self._hasnans:
202-
result[self._isnan] = nat_result
203-
204-
return result
205-
206-
return compat.set_function_name(wrapper, opname, cls)
207-
208-
209129
class DatetimeArray(dtl.DatetimeLikeArrayMixin, dtl.TimelikeOps, dtl.DatelikeOps):
210130
"""
211131
Pandas ExtensionArray for tz-naive or tz-aware datetime data.
@@ -324,7 +244,7 @@ def __init__(self, values, dtype=_NS_DTYPE, freq=None, copy=False):
324244
raise TypeError(msg)
325245
elif values.tz:
326246
dtype = values.dtype
327-
# freq = validate_values_freq(values, freq)
247+
328248
if freq is None:
329249
freq = values.freq
330250
values = values._data
@@ -714,8 +634,6 @@ def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs):
714634
# -----------------------------------------------------------------
715635
# Comparison Methods
716636

717-
_create_comparison_method = classmethod(_dt_array_cmp)
718-
719637
def _has_same_tz(self, other):
720638
zzone = self._timezone
721639

@@ -1767,9 +1685,6 @@ def to_julian_date(self):
17671685
)
17681686

17691687

1770-
DatetimeArray._add_comparison_ops()
1771-
1772-
17731688
# -------------------------------------------------------------------
17741689
# Constructor Helpers
17751690

pandas/core/arrays/period.py

Lines changed: 0 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -20,16 +20,13 @@
2020
period_asfreq_arr,
2121
)
2222
from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
23-
import pandas.compat as compat
2423
from pandas.util._decorators import cache_readonly
2524

2625
from pandas.core.dtypes.common import (
2726
_TD_DTYPE,
2827
ensure_object,
2928
is_datetime64_dtype,
3029
is_float_dtype,
31-
is_list_like,
32-
is_object_dtype,
3330
is_period_dtype,
3431
pandas_dtype,
3532
)
@@ -42,12 +39,9 @@
4239
)
4340
from pandas.core.dtypes.missing import isna, notna
4441

45-
from pandas.core import ops
4642
import pandas.core.algorithms as algos
4743
from pandas.core.arrays import datetimelike as dtl
4844
import pandas.core.common as com
49-
from pandas.core.ops.common import unpack_zerodim_and_defer
50-
from pandas.core.ops.invalid import invalid_comparison
5145

5246
from pandas.tseries import frequencies
5347
from pandas.tseries.offsets import DateOffset, Tick, _delta_to_tick
@@ -64,77 +58,6 @@ def f(self):
6458
return property(f)
6559

6660

67-
def _period_array_cmp(cls, op):
68-
"""
69-
Wrap comparison operations to convert Period-like to PeriodDtype
70-
"""
71-
opname = f"__{op.__name__}__"
72-
nat_result = opname == "__ne__"
73-
74-
@unpack_zerodim_and_defer(opname)
75-
def wrapper(self, other):
76-
77-
if isinstance(other, str):
78-
try:
79-
other = self._scalar_from_string(other)
80-
except ValueError:
81-
# string that can't be parsed as Period
82-
return invalid_comparison(self, other, op)
83-
84-
if isinstance(other, self._recognized_scalars) or other is NaT:
85-
other = self._scalar_type(other)
86-
self._check_compatible_with(other)
87-
88-
other_i8 = self._unbox_scalar(other)
89-
90-
result = op(self.view("i8"), other_i8)
91-
if isna(other):
92-
result.fill(nat_result)
93-
94-
elif not is_list_like(other):
95-
return invalid_comparison(self, other, op)
96-
97-
elif len(other) != len(self):
98-
raise ValueError("Lengths must match")
99-
100-
else:
101-
if isinstance(other, list):
102-
# TODO: could use pd.Index to do inference?
103-
other = np.array(other)
104-
105-
if not isinstance(other, (np.ndarray, cls)):
106-
return invalid_comparison(self, other, op)
107-
108-
if is_object_dtype(other):
109-
with np.errstate(all="ignore"):
110-
result = ops.comp_method_OBJECT_ARRAY(
111-
op, self.astype(object), other
112-
)
113-
o_mask = isna(other)
114-
115-
elif not cls._is_recognized_dtype(other.dtype):
116-
# e.g. is_timedelta64_dtype(other)
117-
return invalid_comparison(self, other, op)
118-
119-
else:
120-
assert isinstance(other, cls), type(other)
121-
122-
self._check_compatible_with(other)
123-
124-
result = op(self.view("i8"), other.view("i8"))
125-
o_mask = other._isnan
126-
127-
if o_mask.any():
128-
result[o_mask] = nat_result
129-
130-
if self._hasnans:
131-
result[self._isnan] = nat_result
132-
133-
return result
134-
135-
return compat.set_function_name(wrapper, opname, cls)
136-
137-
13861
class PeriodArray(dtl.DatetimeLikeArrayMixin, dtl.DatelikeOps):
13962
"""
14063
Pandas ExtensionArray for storing Period data.
@@ -639,7 +562,6 @@ def astype(self, dtype, copy=True):
639562

640563
# ------------------------------------------------------------------
641564
# Arithmetic Methods
642-
_create_comparison_method = classmethod(_period_array_cmp)
643565

644566
def _sub_datelike(self, other):
645567
assert other is not NaT
@@ -810,9 +732,6 @@ def _check_timedeltalike_freq_compat(self, other):
810732
raise raise_on_incompatible(self, other)
811733

812734

813-
PeriodArray._add_comparison_ops()
814-
815-
816735
def raise_on_incompatible(left, right):
817736
"""
818737
Helper function to render a consistent error message when raising

0 commit comments

Comments
 (0)