Skip to content

Commit 9bdc58e

Browse files
jbrockmendeljreback
authored andcommitted
REF: move EA wrapping/unwrapping to indexes.extensions (#30648)
1 parent c82ddcd commit 9bdc58e

File tree

6 files changed

+97
-98
lines changed

6 files changed

+97
-98
lines changed

pandas/core/indexes/category.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import pandas.core.common as com
3030
import pandas.core.indexes.base as ibase
3131
from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name
32+
from pandas.core.indexes.extension import make_wrapped_comparison_op
3233
import pandas.core.missing as missing
3334
from pandas.core.ops import get_op_result_name
3435

@@ -876,14 +877,7 @@ def _add_comparison_methods(cls):
876877
def _make_compare(op):
877878
opname = f"__{op.__name__}__"
878879

879-
def _evaluate_compare(self, other):
880-
with np.errstate(all="ignore"):
881-
result = op(self.array, other)
882-
if isinstance(result, ABCSeries):
883-
# Dispatch to pd.Categorical returned NotImplemented
884-
# and we got a Series back; down-cast to ndarray
885-
result = result._values
886-
return result
880+
_evaluate_compare = make_wrapped_comparison_op(opname)
887881

888882
return compat.set_function_name(_evaluate_compare, opname, cls)
889883

pandas/core/indexes/datetimelike.py

+19-87
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
2727

28-
from pandas.core import algorithms, ops
28+
from pandas.core import algorithms
2929
from pandas.core.accessor import PandasDelegate
3030
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
3131
from pandas.core.arrays.datetimelike import (
@@ -40,21 +40,11 @@
4040

4141
from pandas.tseries.frequencies import DateOffset, to_offset
4242

43-
from .extension import inherit_names
43+
from .extension import inherit_names, make_wrapped_arith_op, make_wrapped_comparison_op
4444

4545
_index_doc_kwargs = dict(ibase._index_doc_kwargs)
4646

4747

48-
def _make_wrapped_arith_op(opname):
49-
def method(self, other):
50-
meth = getattr(self._data, opname)
51-
result = meth(maybe_unwrap_index(other))
52-
return wrap_arithmetic_op(self, other, result)
53-
54-
method.__name__ = opname
55-
return method
56-
57-
5848
def _join_i8_wrapper(joinf, with_indexers: bool = True):
5949
"""
6050
Create the join wrapper methods.
@@ -125,19 +115,7 @@ def _create_comparison_method(cls, op):
125115
"""
126116
Create a comparison method that dispatches to ``cls.values``.
127117
"""
128-
129-
def wrapper(self, other):
130-
if isinstance(other, ABCSeries):
131-
# the arrays defer to Series for comparison ops but the indexes
132-
# don't, so we have to unwrap here.
133-
other = other._values
134-
135-
result = op(self._data, maybe_unwrap_index(other))
136-
return result
137-
138-
wrapper.__doc__ = op.__doc__
139-
wrapper.__name__ = f"__{op.__name__}__"
140-
return wrapper
118+
return make_wrapped_comparison_op(f"__{op.__name__}__")
141119

142120
# ------------------------------------------------------------------------
143121
# Abstract data attributes
@@ -467,22 +445,22 @@ def _convert_scalar_indexer(self, key, kind=None):
467445

468446
return super()._convert_scalar_indexer(key, kind=kind)
469447

470-
__add__ = _make_wrapped_arith_op("__add__")
471-
__radd__ = _make_wrapped_arith_op("__radd__")
472-
__sub__ = _make_wrapped_arith_op("__sub__")
473-
__rsub__ = _make_wrapped_arith_op("__rsub__")
474-
__pow__ = _make_wrapped_arith_op("__pow__")
475-
__rpow__ = _make_wrapped_arith_op("__rpow__")
476-
__mul__ = _make_wrapped_arith_op("__mul__")
477-
__rmul__ = _make_wrapped_arith_op("__rmul__")
478-
__floordiv__ = _make_wrapped_arith_op("__floordiv__")
479-
__rfloordiv__ = _make_wrapped_arith_op("__rfloordiv__")
480-
__mod__ = _make_wrapped_arith_op("__mod__")
481-
__rmod__ = _make_wrapped_arith_op("__rmod__")
482-
__divmod__ = _make_wrapped_arith_op("__divmod__")
483-
__rdivmod__ = _make_wrapped_arith_op("__rdivmod__")
484-
__truediv__ = _make_wrapped_arith_op("__truediv__")
485-
__rtruediv__ = _make_wrapped_arith_op("__rtruediv__")
448+
__add__ = make_wrapped_arith_op("__add__")
449+
__radd__ = make_wrapped_arith_op("__radd__")
450+
__sub__ = make_wrapped_arith_op("__sub__")
451+
__rsub__ = make_wrapped_arith_op("__rsub__")
452+
__pow__ = make_wrapped_arith_op("__pow__")
453+
__rpow__ = make_wrapped_arith_op("__rpow__")
454+
__mul__ = make_wrapped_arith_op("__mul__")
455+
__rmul__ = make_wrapped_arith_op("__rmul__")
456+
__floordiv__ = make_wrapped_arith_op("__floordiv__")
457+
__rfloordiv__ = make_wrapped_arith_op("__rfloordiv__")
458+
__mod__ = make_wrapped_arith_op("__mod__")
459+
__rmod__ = make_wrapped_arith_op("__rmod__")
460+
__divmod__ = make_wrapped_arith_op("__divmod__")
461+
__rdivmod__ = make_wrapped_arith_op("__rdivmod__")
462+
__truediv__ = make_wrapped_arith_op("__truediv__")
463+
__rtruediv__ = make_wrapped_arith_op("__rtruediv__")
486464

487465
def isin(self, values, level=None):
488466
"""
@@ -864,55 +842,13 @@ def _wrap_joined_index(self, joined, other):
864842
return self._simple_new(joined, name, **kwargs)
865843

866844

867-
def wrap_arithmetic_op(self, other, result):
868-
if result is NotImplemented:
869-
return NotImplemented
870-
871-
if isinstance(result, tuple):
872-
# divmod, rdivmod
873-
assert len(result) == 2
874-
return (
875-
wrap_arithmetic_op(self, other, result[0]),
876-
wrap_arithmetic_op(self, other, result[1]),
877-
)
878-
879-
if not isinstance(result, Index):
880-
# Index.__new__ will choose appropriate subclass for dtype
881-
result = Index(result)
882-
883-
res_name = ops.get_op_result_name(self, other)
884-
result.name = res_name
885-
return result
886-
887-
888-
def maybe_unwrap_index(obj):
889-
"""
890-
If operating against another Index object, we need to unwrap the underlying
891-
data before deferring to the DatetimeArray/TimedeltaArray/PeriodArray
892-
implementation, otherwise we will incorrectly return NotImplemented.
893-
894-
Parameters
895-
----------
896-
obj : object
897-
898-
Returns
899-
-------
900-
unwrapped object
901-
"""
902-
if isinstance(obj, ABCIndexClass):
903-
return obj._data
904-
return obj
905-
906-
907845
class DatetimelikeDelegateMixin(PandasDelegate):
908846
"""
909847
Delegation mechanism, specific for Datetime, Timedelta, and Period types.
910848
911849
Functionality is delegated from the Index class to an Array class. A
912850
few things can be customized
913851
914-
* _delegate_class : type
915-
The class being delegated to.
916852
* _delegated_methods, delegated_properties : List
917853
The list of property / method names being delagated.
918854
* raw_methods : Set
@@ -929,10 +865,6 @@ class DatetimelikeDelegateMixin(PandasDelegate):
929865
_raw_properties: Set[str] = set()
930866
_data: ExtensionArray
931867

932-
@property
933-
def _delegate_class(self):
934-
raise AbstractMethodError
935-
936868
def _delegate_property_get(self, name, *args, **kwargs):
937869
result = getattr(self._data, name)
938870
if name not in self._raw_properties:

pandas/core/indexes/datetimes.py

-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ class DatetimeDelegateMixin(DatetimelikeDelegateMixin):
8686
| set(_extra_raw_properties)
8787
)
8888
_raw_methods = set(_extra_raw_methods)
89-
_delegate_class = DatetimeArray
9089

9190

9291
@inherit_names(["_timezone", "is_normalized", "_resolution"], DatetimeArray, cache=True)

pandas/core/indexes/extension.py

+76
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55

66
from pandas.util._decorators import cache_readonly
77

8+
from pandas.core.dtypes.generic import ABCSeries
9+
10+
from pandas.core.ops import get_op_result_name
11+
12+
from .base import Index
13+
814

915
def inherit_from_data(name: str, delegate, cache: bool = False):
1016
"""
@@ -76,3 +82,73 @@ def wrapper(cls):
7682
return cls
7783

7884
return wrapper
85+
86+
87+
def make_wrapped_comparison_op(opname):
88+
"""
89+
Create a comparison method that dispatches to ``._data``.
90+
"""
91+
92+
def wrapper(self, other):
93+
if isinstance(other, ABCSeries):
94+
# the arrays defer to Series for comparison ops but the indexes
95+
# don't, so we have to unwrap here.
96+
other = other._values
97+
98+
other = _maybe_unwrap_index(other)
99+
100+
op = getattr(self._data, opname)
101+
return op(other)
102+
103+
wrapper.__name__ = opname
104+
return wrapper
105+
106+
107+
def make_wrapped_arith_op(opname):
108+
def method(self, other):
109+
meth = getattr(self._data, opname)
110+
result = meth(_maybe_unwrap_index(other))
111+
return _wrap_arithmetic_op(self, other, result)
112+
113+
method.__name__ = opname
114+
return method
115+
116+
117+
def _wrap_arithmetic_op(self, other, result):
118+
if result is NotImplemented:
119+
return NotImplemented
120+
121+
if isinstance(result, tuple):
122+
# divmod, rdivmod
123+
assert len(result) == 2
124+
return (
125+
_wrap_arithmetic_op(self, other, result[0]),
126+
_wrap_arithmetic_op(self, other, result[1]),
127+
)
128+
129+
if not isinstance(result, Index):
130+
# Index.__new__ will choose appropriate subclass for dtype
131+
result = Index(result)
132+
133+
res_name = get_op_result_name(self, other)
134+
result.name = res_name
135+
return result
136+
137+
138+
def _maybe_unwrap_index(obj):
139+
"""
140+
If operating against another Index object, we need to unwrap the underlying
141+
data before deferring to the DatetimeArray/TimedeltaArray/PeriodArray
142+
implementation, otherwise we will incorrectly return NotImplemented.
143+
144+
Parameters
145+
----------
146+
obj : object
147+
148+
Returns
149+
-------
150+
unwrapped object
151+
"""
152+
if isinstance(obj, Index):
153+
return obj._data
154+
return obj

pandas/core/indexes/period.py

-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ class PeriodDelegateMixin(DatetimelikeDelegateMixin):
6666
Delegate from PeriodIndex to PeriodArray.
6767
"""
6868

69-
_delegate_class = PeriodArray
7069
_raw_methods = {"_format_native_types"}
7170
_raw_properties = {"is_leap_year", "freq"}
7271

pandas/core/indexes/timedeltas.py

-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class TimedeltaDelegateMixin(DatetimelikeDelegateMixin):
4141
# Some are "raw" methods, the result is not re-boxed in an Index
4242
# We also have a few "extra" attrs, which may or may not be raw,
4343
# which we don't want to expose in the .dt accessor.
44-
_delegate_class = TimedeltaArray
4544
_raw_properties = {"components", "_box_func"}
4645
_raw_methods = {"to_pytimedelta", "sum", "std", "median", "_format_native_types"}
4746

0 commit comments

Comments
 (0)