Skip to content

Commit 257ad4e

Browse files
authored
TYP/REF: define comparison methods non-dynamically (#36930)
1 parent 8e1cc56 commit 257ad4e

File tree

8 files changed

+138
-135
lines changed

8 files changed

+138
-135
lines changed

pandas/core/arraylike.py

+43
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Methods that can be shared by many array-like classes or subclasses:
3+
Series
4+
Index
5+
ExtensionArray
6+
"""
7+
import operator
8+
9+
from pandas.errors import AbstractMethodError
10+
11+
from pandas.core.ops.common import unpack_zerodim_and_defer
12+
13+
14+
class OpsMixin:
    """
    Mixin that defines the six rich-comparison dunder methods statically.

    Every comparison forwards to ``_cmp_method`` together with the matching
    function from the ``operator`` module, so a subclass only has to
    implement that single hook to get ``==``, ``!=``, ``<``, ``<=``, ``>``
    and ``>=``.

    Notes
    -----
    Because this class defines ``__eq__`` without defining ``__hash__``,
    Python sets ``__hash__`` to ``None`` on it; subclasses that must be
    hashable have to define ``__hash__`` themselves.
    """

    # -------------------------------------------------------------
    # Comparisons

    def _cmp_method(self, other, op):
        """
        Evaluate the comparison ``op(self, other)``.

        This base implementation is abstract and must be overridden.

        Parameters
        ----------
        other : object
            Right-hand operand of the comparison.
        op : callable
            The ``operator`` function for the comparison being performed
            (``operator.eq``, ``ne``, ``lt``, ``le``, ``gt`` or ``ge``).

        Raises
        ------
        AbstractMethodError
            Always, unless a subclass overrides this method.
        """
        raise AbstractMethodError(self)

    # NOTE(review): unpack_zerodim_and_defer comes from
    # pandas.core.ops.common; going by its name it unwraps zero-dimensional
    # operands and defers to other pandas types before the method body
    # runs -- confirm against that module.  The string argument is the
    # dunder name of the method being wrapped.

    @unpack_zerodim_and_defer("__eq__")
    def __eq__(self, other):
        return self._cmp_method(other, operator.eq)

    @unpack_zerodim_and_defer("__ne__")
    def __ne__(self, other):
        return self._cmp_method(other, operator.ne)

    @unpack_zerodim_and_defer("__lt__")
    def __lt__(self, other):
        return self._cmp_method(other, operator.lt)

    @unpack_zerodim_and_defer("__le__")
    def __le__(self, other):
        return self._cmp_method(other, operator.le)

    @unpack_zerodim_and_defer("__gt__")
    def __gt__(self, other):
        return self._cmp_method(other, operator.gt)

    @unpack_zerodim_and_defer("__ge__")
    def __ge__(self, other):
        return self._cmp_method(other, operator.ge)

pandas/core/arrays/datetimelike.py

+31 -48
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
round_nsint64,
2525
)
2626
from pandas._typing import DatetimeLikeScalar, DtypeObj
27-
from pandas.compat import set_function_name
2827
from pandas.compat.numpy import function as nv
2928
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
3029
from pandas.util._decorators import Appender, Substitution, cache_readonly
@@ -51,8 +50,8 @@
5150

5251
from pandas.core import nanops, ops
5352
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
53+
from pandas.core.arraylike import OpsMixin
5454
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
55-
from pandas.core.arrays.base import ExtensionOpsMixin
5655
import pandas.core.common as com
5756
from pandas.core.construction import array, extract_array
5857
from pandas.core.indexers import check_array_indexer, check_setitem_lengths
@@ -73,46 +72,6 @@ class InvalidComparison(Exception):
7372
pass
7473

7574

76-
def _datetimelike_array_cmp(cls, op):
77-
"""
78-
Wrap comparison operations to convert Timestamp/Timedelta/Period-like to
79-
boxed scalars/arrays.
80-
"""
81-
opname = f"__{op.__name__}__"
82-
nat_result = opname == "__ne__"
83-
84-
@unpack_zerodim_and_defer(opname)
85-
def wrapper(self, other):
86-
if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
87-
# TODO: handle 2D-like listlikes
88-
return op(self.ravel(), other.ravel()).reshape(self.shape)
89-
90-
try:
91-
other = self._validate_comparison_value(other, opname)
92-
except InvalidComparison:
93-
return invalid_comparison(self, other, op)
94-
95-
dtype = getattr(other, "dtype", None)
96-
if is_object_dtype(dtype):
97-
# We have to use comp_method_OBJECT_ARRAY instead of numpy
98-
# comparison otherwise it would fail to raise when
99-
# comparing tz-aware and tz-naive
100-
with np.errstate(all="ignore"):
101-
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
102-
return result
103-
104-
other_i8 = self._unbox(other)
105-
result = op(self.asi8, other_i8)
106-
107-
o_mask = isna(other)
108-
if self._hasnans | np.any(o_mask):
109-
result[self._isnan | o_mask] = nat_result
110-
111-
return result
112-
113-
return set_function_name(wrapper, opname, cls)
114-
115-
11675
class AttributesMixin:
11776
_data: np.ndarray
11877

@@ -426,9 +385,7 @@ def _with_freq(self, freq):
426385
DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin")
427386

428387

429-
class DatetimeLikeArrayMixin(
430-
ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray
431-
):
388+
class DatetimeLikeArrayMixin(OpsMixin, AttributesMixin, NDArrayBackedExtensionArray):
432389
"""
433390
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
434391
@@ -1093,7 +1050,35 @@ def _is_unique(self):
10931050

10941051
# ------------------------------------------------------------------
10951052
# Arithmetic Methods
1096-
_create_comparison_method = classmethod(_datetimelike_array_cmp)
1053+
1054+
def _cmp_method(self, other, op):
1055+
if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
1056+
# TODO: handle 2D-like listlikes
1057+
return op(self.ravel(), other.ravel()).reshape(self.shape)
1058+
1059+
try:
1060+
other = self._validate_comparison_value(other, f"__{op.__name__}__")
1061+
except InvalidComparison:
1062+
return invalid_comparison(self, other, op)
1063+
1064+
dtype = getattr(other, "dtype", None)
1065+
if is_object_dtype(dtype):
1066+
# We have to use comp_method_OBJECT_ARRAY instead of numpy
1067+
# comparison otherwise it would fail to raise when
1068+
# comparing tz-aware and tz-naive
1069+
with np.errstate(all="ignore"):
1070+
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
1071+
return result
1072+
1073+
other_i8 = self._unbox(other)
1074+
result = op(self.asi8, other_i8)
1075+
1076+
o_mask = isna(other)
1077+
if self._hasnans | np.any(o_mask):
1078+
nat_result = op is operator.ne
1079+
result[self._isnan | o_mask] = nat_result
1080+
1081+
return result
10971082

10981083
# pow is invalid for all three subclasses; TimedeltaArray will override
10991084
# the multiplication and division ops
@@ -1582,8 +1567,6 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg
15821567
return self._from_backing_data(result.astype("i8"))
15831568

15841569

1585-
DatetimeLikeArrayMixin._add_comparison_ops()
1586-
15871570
# -------------------------------------------------------------------
15881571
# Shared Constructor Helpers
15891572

pandas/core/base.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from pandas.core import algorithms, common as com
3131
from pandas.core.accessor import DirNamesMixin
3232
from pandas.core.algorithms import duplicated, unique1d, value_counts
33+
from pandas.core.arraylike import OpsMixin
3334
from pandas.core.arrays import ExtensionArray
3435
from pandas.core.construction import create_series_with_explicit_dtype
3536
import pandas.core.nanops as nanops
@@ -587,7 +588,7 @@ def _is_builtin_func(self, arg):
587588
return self._builtin_table.get(arg, arg)
588589

589590

590-
class IndexOpsMixin:
591+
class IndexOpsMixin(OpsMixin):
591592
"""
592593
Common ops mixin to support a unified interface / docs for Series / Index
593594
"""

pandas/core/generic.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1764,7 +1764,7 @@ def _drop_labels_or_levels(self, keys, axis: int = 0):
17641764
# ----------------------------------------------------------------------
17651765
# Iteration
17661766

1767-
def __hash__(self):
1767+
def __hash__(self) -> int:
17681768
raise TypeError(
17691769
f"{repr(type(self).__name__)} objects are mutable, "
17701770
f"thus they cannot be hashed"

pandas/core/indexes/base.py

+30 -45
Original file line number | Diff line number | Diff line change
@@ -121,41 +121,6 @@
121121
str_t = str
122122

123123

124-
def _make_comparison_op(op, cls):
125-
def cmp_method(self, other):
126-
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
127-
if other.ndim > 0 and len(self) != len(other):
128-
raise ValueError("Lengths must match to compare")
129-
130-
if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
131-
left = type(other)(self._values, dtype=other.dtype)
132-
return op(left, other)
133-
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
134-
# e.g. PeriodArray
135-
with np.errstate(all="ignore"):
136-
result = op(self._values, other)
137-
138-
elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
139-
# don't pass MultiIndex
140-
with np.errstate(all="ignore"):
141-
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)
142-
143-
elif is_interval_dtype(self.dtype):
144-
with np.errstate(all="ignore"):
145-
result = op(self._values, np.asarray(other))
146-
147-
else:
148-
with np.errstate(all="ignore"):
149-
result = ops.comparison_op(self._values, np.asarray(other), op)
150-
151-
if is_bool_dtype(result):
152-
return result
153-
return ops.invalid_comparison(self, other, op)
154-
155-
name = f"__{op.__name__}__"
156-
return set_function_name(cmp_method, name, cls)
157-
158-
159124
def _make_arithmetic_op(op, cls):
160125
def index_arithmetic_method(self, other):
161126
if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)):
@@ -5400,17 +5365,38 @@ def drop(self, labels, errors: str_t = "raise"):
54005365
# --------------------------------------------------------------------
54015366
# Generated Arithmetic, Comparison, and Unary Methods
54025367

5403-
@classmethod
5404-
def _add_comparison_methods(cls):
5368+
def _cmp_method(self, other, op):
54055369
"""
5406-
Add in comparison methods.
5370+
Wrapper used to dispatch comparison operations.
54075371
"""
5408-
cls.__eq__ = _make_comparison_op(operator.eq, cls)
5409-
cls.__ne__ = _make_comparison_op(operator.ne, cls)
5410-
cls.__lt__ = _make_comparison_op(operator.lt, cls)
5411-
cls.__gt__ = _make_comparison_op(operator.gt, cls)
5412-
cls.__le__ = _make_comparison_op(operator.le, cls)
5413-
cls.__ge__ = _make_comparison_op(operator.ge, cls)
5372+
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
5373+
if other.ndim > 0 and len(self) != len(other):
5374+
raise ValueError("Lengths must match to compare")
5375+
5376+
if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
5377+
left = type(other)(self._values, dtype=other.dtype)
5378+
return op(left, other)
5379+
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
5380+
# e.g. PeriodArray
5381+
with np.errstate(all="ignore"):
5382+
result = op(self._values, other)
5383+
5384+
elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
5385+
# don't pass MultiIndex
5386+
with np.errstate(all="ignore"):
5387+
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)
5388+
5389+
elif is_interval_dtype(self.dtype):
5390+
with np.errstate(all="ignore"):
5391+
result = op(self._values, np.asarray(other))
5392+
5393+
else:
5394+
with np.errstate(all="ignore"):
5395+
result = ops.comparison_op(self._values, np.asarray(other), op)
5396+
5397+
if is_bool_dtype(result):
5398+
return result
5399+
return ops.invalid_comparison(self, other, op)
54145400

54155401
@classmethod
54165402
def _add_numeric_methods_binary(cls):
@@ -5594,7 +5580,6 @@ def shape(self):
55945580

55955581
Index._add_numeric_methods()
55965582
Index._add_logical_methods()
5597-
Index._add_comparison_methods()
55985583

55995584

56005585
def ensure_index_from_sequences(sequences, names=None):

pandas/core/ops/__init__.py

+2 -29
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,13 @@
2020

2121
from pandas.core import algorithms
2222
from pandas.core.construction import extract_array
23-
from pandas.core.ops.array_ops import (
23+
from pandas.core.ops.array_ops import ( # noqa:F401
2424
arithmetic_op,
25+
comp_method_OBJECT_ARRAY,
2526
comparison_op,
2627
get_array_op,
2728
logical_op,
2829
)
29-
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401
3030
from pandas.core.ops.common import unpack_zerodim_and_defer
3131
from pandas.core.ops.docstrings import (
3232
_arith_doc_FRAME,
@@ -323,33 +323,6 @@ def wrapper(left, right):
323323
return wrapper
324324

325325

326-
def comp_method_SERIES(cls, op, special):
327-
"""
328-
Wrapper function for Series arithmetic operations, to avoid
329-
code duplication.
330-
"""
331-
assert special # non-special uses flex_method_SERIES
332-
op_name = _get_op_name(op, special)
333-
334-
@unpack_zerodim_and_defer(op_name)
335-
def wrapper(self, other):
336-
337-
res_name = get_op_result_name(self, other)
338-
339-
if isinstance(other, ABCSeries) and not self._indexed_same(other):
340-
raise ValueError("Can only compare identically-labeled Series objects")
341-
342-
lvalues = extract_array(self, extract_numpy=True)
343-
rvalues = extract_array(other, extract_numpy=True)
344-
345-
res_values = comparison_op(lvalues, rvalues, op)
346-
347-
return self._construct_result(res_values, name=res_name)
348-
349-
wrapper.__name__ = op_name
350-
return wrapper
351-
352-
353326
def bool_method_SERIES(cls, op, special):
354327
"""
355328
Wrapper function for Series arithmetic operations, to avoid

pandas/core/ops/methods.py

+12 -11
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@ def _get_method_wrappers(cls):
4848
arith_method_SERIES,
4949
bool_method_SERIES,
5050
comp_method_FRAME,
51-
comp_method_SERIES,
5251
flex_comp_method_FRAME,
5352
flex_method_SERIES,
5453
)
@@ -58,7 +57,7 @@ def _get_method_wrappers(cls):
5857
arith_flex = flex_method_SERIES
5958
comp_flex = flex_method_SERIES
6059
arith_special = arith_method_SERIES
61-
comp_special = comp_method_SERIES
60+
comp_special = None
6261
bool_special = bool_method_SERIES
6362
elif issubclass(cls, ABCDataFrame):
6463
arith_flex = arith_method_FRAME
@@ -189,16 +188,18 @@ def _create_methods(cls, arith_method, comp_method, bool_method, special):
189188
new_methods["divmod"] = arith_method(cls, divmod, special)
190189
new_methods["rdivmod"] = arith_method(cls, rdivmod, special)
191190

192-
new_methods.update(
193-
dict(
194-
eq=comp_method(cls, operator.eq, special),
195-
ne=comp_method(cls, operator.ne, special),
196-
lt=comp_method(cls, operator.lt, special),
197-
gt=comp_method(cls, operator.gt, special),
198-
le=comp_method(cls, operator.le, special),
199-
ge=comp_method(cls, operator.ge, special),
191+
if comp_method is not None:
192+
# Series already has this pinned
193+
new_methods.update(
194+
dict(
195+
eq=comp_method(cls, operator.eq, special),
196+
ne=comp_method(cls, operator.ne, special),
197+
lt=comp_method(cls, operator.lt, special),
198+
gt=comp_method(cls, operator.gt, special),
199+
le=comp_method(cls, operator.le, special),
200+
ge=comp_method(cls, operator.ge, special),
201+
)
200202
)
201-
)
202203

203204
if bool_method:
204205
new_methods.update(

0 commit comments

Comments
 (0)