Skip to content

Commit e9d0a58

Browse files
authored
REF: Share Index comparison and arithmetic methods (#43555)
1 parent e7efcca commit e9d0a58

File tree

6 files changed

+59
-151
lines changed

6 files changed

+59
-151
lines changed

pandas/core/base.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@
5252
remove_na_arraylike,
5353
)
5454

55-
from pandas.core import algorithms
55+
from pandas.core import (
56+
algorithms,
57+
ops,
58+
)
5659
from pandas.core.accessor import DirNamesMixin
5760
from pandas.core.algorithms import (
5861
duplicated,
@@ -61,7 +64,11 @@
6164
)
6265
from pandas.core.arraylike import OpsMixin
6366
from pandas.core.arrays import ExtensionArray
64-
from pandas.core.construction import create_series_with_explicit_dtype
67+
from pandas.core.construction import (
68+
create_series_with_explicit_dtype,
69+
ensure_wrapped_if_datetimelike,
70+
extract_array,
71+
)
6572
import pandas.core.nanops as nanops
6673

6774
if TYPE_CHECKING:
@@ -1238,3 +1245,23 @@ def _duplicated(
12381245
self, keep: Literal["first", "last", False] = "first"
12391246
) -> npt.NDArray[np.bool_]:
12401247
return duplicated(self._values, keep=keep)
1248+
1249+
def _arith_method(self, other, op):
1250+
res_name = ops.get_op_result_name(self, other)
1251+
1252+
lvalues = self._values
1253+
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
1254+
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
1255+
rvalues = ensure_wrapped_if_datetimelike(rvalues)
1256+
1257+
with np.errstate(all="ignore"):
1258+
result = ops.arithmetic_op(lvalues, rvalues, op)
1259+
1260+
return self._construct_result(result, name=res_name)
1261+
1262+
def _construct_result(self, result, name):
1263+
"""
1264+
Construct an appropriately-wrapped result from the ArrayLike result
1265+
of an arithmetic-like operation.
1266+
"""
1267+
raise AbstractMethodError(self)

pandas/core/indexes/base.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -6364,7 +6364,10 @@ def _cmp_method(self, other, op):
63646364
arr[self.isna()] = False
63656365
return arr
63666366
elif op in {operator.ne, operator.lt, operator.gt}:
6367-
return np.zeros(len(self), dtype=bool)
6367+
arr = np.zeros(len(self), dtype=bool)
6368+
if self._can_hold_na and not isinstance(self, ABCMultiIndex):
6369+
arr[self.isna()] = True
6370+
return arr
63686371

63696372
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)) and len(
63706373
self
@@ -6381,6 +6384,9 @@ def _cmp_method(self, other, op):
63816384
with np.errstate(all="ignore"):
63826385
result = op(self._values, other)
63836386

6387+
elif isinstance(self._values, ExtensionArray):
6388+
result = op(self._values, other)
6389+
63846390
elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
63856391
# don't pass MultiIndex
63866392
with np.errstate(all="ignore"):
@@ -6392,17 +6398,26 @@ def _cmp_method(self, other, op):
63926398

63936399
return result
63946400

6395-
def _arith_method(self, other, op):
6396-
"""
6397-
Wrapper used to dispatch arithmetic operations.
6398-
"""
6401+
def _construct_result(self, result, name):
6402+
if isinstance(result, tuple):
6403+
return (
6404+
Index._with_infer(result[0], name=name),
6405+
Index._with_infer(result[1], name=name),
6406+
)
6407+
return Index._with_infer(result, name=name)
63996408

6400-
from pandas import Series
6409+
def _arith_method(self, other, op):
6410+
if (
6411+
isinstance(other, Index)
6412+
and is_object_dtype(other.dtype)
6413+
and type(other) is not Index
6414+
):
6415+
# We return NotImplemented for object-dtype index *subclasses* so they have
6416+
# a chance to implement ops before we unwrap them.
6417+
# See https://github.com/pandas-dev/pandas/issues/31109
6418+
return NotImplemented
64016419

6402-
result = op(Series(self), other)
6403-
if isinstance(result, tuple):
6404-
return (Index._with_infer(result[0]), Index(result[1]))
6405-
return Index._with_infer(result)
6420+
return super()._arith_method(other, op)
64066421

64076422
@final
64086423
def _unary_method(self, op):

pandas/core/indexes/extension.py

+1-118
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,9 @@
2626

2727
from pandas.core.dtypes.common import (
2828
is_dtype_equal,
29-
is_object_dtype,
3029
pandas_dtype,
3130
)
32-
from pandas.core.dtypes.generic import (
33-
ABCDataFrame,
34-
ABCSeries,
35-
)
31+
from pandas.core.dtypes.generic import ABCDataFrame
3632

3733
from pandas.core.arrays import (
3834
Categorical,
@@ -45,7 +41,6 @@
4541
from pandas.core.arrays.base import ExtensionArray
4642
from pandas.core.indexers import deprecate_ndim_indexing
4743
from pandas.core.indexes.base import Index
48-
from pandas.core.ops import get_op_result_name
4944

5045
if TYPE_CHECKING:
5146

@@ -154,94 +149,6 @@ def wrapper(cls):
154149
return wrapper
155150

156151

157-
def _make_wrapped_comparison_op(opname: str):
158-
"""
159-
Create a comparison method that dispatches to ``._data``.
160-
"""
161-
162-
def wrapper(self, other):
163-
if isinstance(other, ABCSeries):
164-
# the arrays defer to Series for comparison ops but the indexes
165-
# don't, so we have to unwrap here.
166-
other = other._values
167-
168-
other = _maybe_unwrap_index(other)
169-
170-
op = getattr(self._data, opname)
171-
return op(other)
172-
173-
wrapper.__name__ = opname
174-
return wrapper
175-
176-
177-
def _make_wrapped_arith_op(opname: str):
178-
def method(self, other):
179-
if (
180-
isinstance(other, Index)
181-
and is_object_dtype(other.dtype)
182-
and type(other) is not Index
183-
):
184-
# We return NotImplemented for object-dtype index *subclasses* so they have
185-
# a chance to implement ops before we unwrap them.
186-
# See https://github.com/pandas-dev/pandas/issues/31109
187-
return NotImplemented
188-
189-
try:
190-
meth = getattr(self._data, opname)
191-
except AttributeError as err:
192-
# e.g. Categorical, IntervalArray
193-
cls = type(self).__name__
194-
raise TypeError(
195-
f"cannot perform {opname} with this index type: {cls}"
196-
) from err
197-
198-
result = meth(_maybe_unwrap_index(other))
199-
return _wrap_arithmetic_op(self, other, result)
200-
201-
method.__name__ = opname
202-
return method
203-
204-
205-
def _wrap_arithmetic_op(self, other, result):
206-
if result is NotImplemented:
207-
return NotImplemented
208-
209-
if isinstance(result, tuple):
210-
# divmod, rdivmod
211-
assert len(result) == 2
212-
return (
213-
_wrap_arithmetic_op(self, other, result[0]),
214-
_wrap_arithmetic_op(self, other, result[1]),
215-
)
216-
217-
if not isinstance(result, Index):
218-
# Index.__new__ will choose appropriate subclass for dtype
219-
result = Index(result)
220-
221-
res_name = get_op_result_name(self, other)
222-
result.name = res_name
223-
return result
224-
225-
226-
def _maybe_unwrap_index(obj):
227-
"""
228-
If operating against another Index object, we need to unwrap the underlying
229-
data before deferring to the DatetimeArray/TimedeltaArray/PeriodArray
230-
implementation, otherwise we will incorrectly return NotImplemented.
231-
232-
Parameters
233-
----------
234-
obj : object
235-
236-
Returns
237-
-------
238-
unwrapped object
239-
"""
240-
if isinstance(obj, Index):
241-
return obj._data
242-
return obj
243-
244-
245152
class ExtensionIndex(Index):
246153
"""
247154
Index subclass for indexes backed by ExtensionArray.
@@ -284,30 +191,6 @@ def _simple_new(
284191
result._reset_identity()
285192
return result
286193

287-
__eq__ = _make_wrapped_comparison_op("__eq__")
288-
__ne__ = _make_wrapped_comparison_op("__ne__")
289-
__lt__ = _make_wrapped_comparison_op("__lt__")
290-
__gt__ = _make_wrapped_comparison_op("__gt__")
291-
__le__ = _make_wrapped_comparison_op("__le__")
292-
__ge__ = _make_wrapped_comparison_op("__ge__")
293-
294-
__add__ = _make_wrapped_arith_op("__add__")
295-
__sub__ = _make_wrapped_arith_op("__sub__")
296-
__radd__ = _make_wrapped_arith_op("__radd__")
297-
__rsub__ = _make_wrapped_arith_op("__rsub__")
298-
__pow__ = _make_wrapped_arith_op("__pow__")
299-
__rpow__ = _make_wrapped_arith_op("__rpow__")
300-
__mul__ = _make_wrapped_arith_op("__mul__")
301-
__rmul__ = _make_wrapped_arith_op("__rmul__")
302-
__floordiv__ = _make_wrapped_arith_op("__floordiv__")
303-
__rfloordiv__ = _make_wrapped_arith_op("__rfloordiv__")
304-
__mod__ = _make_wrapped_arith_op("__mod__")
305-
__rmod__ = _make_wrapped_arith_op("__rmod__")
306-
__divmod__ = _make_wrapped_arith_op("__divmod__")
307-
__rdivmod__ = _make_wrapped_arith_op("__rdivmod__")
308-
__truediv__ = _make_wrapped_arith_op("__truediv__")
309-
__rtruediv__ = _make_wrapped_arith_op("__rtruediv__")
310-
311194
# ---------------------------------------------------------------------
312195
# NDarray-Like Methods
313196

pandas/core/series.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
import pandas.core.common as com
105105
from pandas.core.construction import (
106106
create_series_with_explicit_dtype,
107-
ensure_wrapped_if_datetimelike,
108107
extract_array,
109108
is_empty_data,
110109
sanitize_array,
@@ -5515,18 +5514,8 @@ def _logical_method(self, other, op):
55155514
return self._construct_result(res_values, name=res_name)
55165515

55175516
def _arith_method(self, other, op):
5518-
res_name = ops.get_op_result_name(self, other)
55195517
self, other = ops.align_method_SERIES(self, other)
5520-
5521-
lvalues = self._values
5522-
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
5523-
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
5524-
rvalues = ensure_wrapped_if_datetimelike(rvalues)
5525-
5526-
with np.errstate(all="ignore"):
5527-
result = ops.arithmetic_op(lvalues, rvalues, op)
5528-
5529-
return self._construct_result(result, name=res_name)
5518+
return base.IndexOpsMixin._arith_method(self, other, op)
55305519

55315520

55325521
Series._add_numeric_operations()

pandas/tests/arithmetic/test_datetime64.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,7 @@ def test_dti_sub_tdi(self, tz_naive_fixture):
21462146
result = dti - tdi.values
21472147
tm.assert_index_equal(result, expected)
21482148

2149-
msg = "cannot subtract DatetimeArray from"
2149+
msg = "cannot subtract a datelike from a TimedeltaArray"
21502150
with pytest.raises(TypeError, match=msg):
21512151
tdi.values - dti
21522152

@@ -2172,13 +2172,7 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
21722172
result -= tdi.values
21732173
tm.assert_index_equal(result, expected)
21742174

2175-
msg = "|".join(
2176-
[
2177-
"cannot perform __neg__ with this index type:",
2178-
"ufunc subtract cannot use operands with types",
2179-
"cannot subtract DatetimeArray from",
2180-
]
2181-
)
2175+
msg = "cannot subtract a datelike from a TimedeltaArray"
21822176
with pytest.raises(TypeError, match=msg):
21832177
tdi.values -= dti
21842178

pandas/tests/arithmetic/test_period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):
753753

754754
with pytest.raises(TypeError, match=msg):
755755
rng - tdarr
756-
msg = r"cannot subtract PeriodArray from timedelta64\[ns\]"
756+
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
757757
with pytest.raises(TypeError, match=msg):
758758
tdarr - rng
759759

0 commit comments

Comments
 (0)