Skip to content

Commit 3dc59d4

Browse files
COMPAT: Return NotImplemented for subclassing (#31136)
* COMPAT: Return NotImplemented for subclassing This changes index ops to check the *type* of the argument in index ops, rather than just the dtype. This lets index subclasses take control of binary ops when they know better what the result should be. Closes #31109
1 parent 0575149 commit 3dc59d4

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

pandas/core/indexes/extension.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from pandas.compat.numpy import function as nv
99
from pandas.util._decorators import Appender, cache_readonly
1010

11-
from pandas.core.dtypes.common import ensure_platform_int, is_dtype_equal
11+
from pandas.core.dtypes.common import (
12+
ensure_platform_int,
13+
is_dtype_equal,
14+
is_object_dtype,
15+
)
1216
from pandas.core.dtypes.generic import ABCSeries
1317

1418
from pandas.core.arrays import ExtensionArray
@@ -129,6 +133,15 @@ def wrapper(self, other):
129133

130134
def make_wrapped_arith_op(opname):
131135
def method(self, other):
136+
if (
137+
isinstance(other, Index)
138+
and is_object_dtype(other.dtype)
139+
and type(other) is not Index
140+
):
141+
# We return NotImplemented for object-dtype index *subclasses* so they have
142+
# a chance to implement ops before we unwrap them.
143+
# See https://github.com/pandas-dev/pandas/issues/31109
144+
return NotImplemented
132145
meth = getattr(self._data, opname)
133146
result = meth(_maybe_unwrap_index(other))
134147
return _wrap_arithmetic_op(self, other, result)

pandas/tests/arithmetic/test_object.py

+46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Arithmetic tests for DataFrame/Series/Index/Array classes that should
22
# behave identically.
33
# Specifically for object dtype
4+
import datetime
45
from decimal import Decimal
56
import operator
67

@@ -328,3 +329,48 @@ def test_rsub_object(self):
328329

329330
with pytest.raises(TypeError, match=msg):
330331
np.array([True, pd.Timestamp.now()]) - index
332+
333+
334+
class MyIndex(pd.Index):
335+
# Simple index subclass that tracks ops calls.
336+
337+
_calls: int
338+
339+
@classmethod
340+
def _simple_new(cls, values, name=None, dtype=None):
341+
result = object.__new__(cls)
342+
result._data = values
343+
result._index_data = values
344+
result._name = name
345+
result._calls = 0
346+
347+
return result._reset_identity()
348+
349+
def __add__(self, other):
350+
self._calls += 1
351+
return self._simple_new(self._index_data)
352+
353+
def __radd__(self, other):
354+
return self.__add__(other)
355+
356+
357+
@pytest.mark.parametrize(
358+
"other",
359+
[
360+
[datetime.timedelta(1), datetime.timedelta(2)],
361+
[datetime.datetime(2000, 1, 1), datetime.datetime(2000, 1, 2)],
362+
[pd.Period("2000"), pd.Period("2001")],
363+
["a", "b"],
364+
],
365+
ids=["timedelta", "datetime", "period", "object"],
366+
)
367+
def test_index_ops_defer_to_unknown_subclasses(other):
368+
# https://github.com/pandas-dev/pandas/issues/31109
369+
values = np.array(
370+
[datetime.date(2000, 1, 1), datetime.date(2000, 1, 2)], dtype=object
371+
)
372+
a = MyIndex._simple_new(values)
373+
other = pd.Index(other)
374+
result = other + a
375+
assert isinstance(result, MyIndex)
376+
assert a._calls == 1

0 commit comments

Comments
 (0)