Skip to content

Commit 319a6d3

Browse files
authored
REF: use unpack_zerodim_and_defer on EA methods (#34042)
1 parent 23dcab5 commit 319a6d3

File tree

6 files changed

+28
-33
lines changed

6 files changed

+28
-33
lines changed

pandas/core/arrays/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pandas.core.dtypes.cast import maybe_cast_to_extension_array
2323
from pandas.core.dtypes.common import is_array_like, is_list_like, pandas_dtype
2424
from pandas.core.dtypes.dtypes import ExtensionDtype
25-
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
25+
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
2626
from pandas.core.dtypes.missing import isna
2727

2828
from pandas.core import ops
@@ -1273,7 +1273,7 @@ def convert_values(param):
12731273
ovalues = [param] * len(self)
12741274
return ovalues
12751275

1276-
if isinstance(other, (ABCSeries, ABCIndexClass)):
1276+
if isinstance(other, (ABCSeries, ABCIndexClass, ABCDataFrame)):
12771277
# rely on pandas to unbox and dispatch to us
12781278
return NotImplemented
12791279

pandas/core/arrays/boolean.py

+4-15
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
pandas_dtype,
2121
)
2222
from pandas.core.dtypes.dtypes import register_extension_dtype
23-
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
2423
from pandas.core.dtypes.missing import isna
2524

2625
from pandas.core import ops
@@ -559,13 +558,10 @@ def all(self, skipna: bool = True, **kwargs):
559558

560559
@classmethod
561560
def _create_logical_method(cls, op):
561+
@ops.unpack_zerodim_and_defer(op.__name__)
562562
def logical_method(self, other):
563-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
564-
# Rely on pandas to unbox and dispatch to us.
565-
return NotImplemented
566563

567564
assert op.__name__ in {"or_", "ror_", "and_", "rand_", "xor", "rxor"}
568-
other = lib.item_from_zerodim(other)
569565
other_is_booleanarray = isinstance(other, BooleanArray)
570566
other_is_scalar = lib.is_scalar(other)
571567
mask = None
@@ -605,16 +601,14 @@ def logical_method(self, other):
605601

606602
@classmethod
607603
def _create_comparison_method(cls, op):
604+
@ops.unpack_zerodim_and_defer(op.__name__)
608605
def cmp_method(self, other):
609606
from pandas.arrays import IntegerArray
610607

611-
if isinstance(
612-
other, (ABCDataFrame, ABCSeries, ABCIndexClass, IntegerArray)
613-
):
608+
if isinstance(other, IntegerArray):
614609
# Rely on pandas to unbox and dispatch to us.
615610
return NotImplemented
616611

617-
other = lib.item_from_zerodim(other)
618612
mask = None
619613

620614
if isinstance(other, BooleanArray):
@@ -693,13 +687,8 @@ def _maybe_mask_result(self, result, mask, other, op_name: str):
693687
def _create_arithmetic_method(cls, op):
694688
op_name = op.__name__
695689

690+
@ops.unpack_zerodim_and_defer(op_name)
696691
def boolean_arithmetic_method(self, other):
697-
698-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
699-
# Rely on pandas to unbox and dispatch to us.
700-
return NotImplemented
701-
702-
other = lib.item_from_zerodim(other)
703692
mask = None
704693

705694
if isinstance(other, BooleanArray):

pandas/core/arrays/numpy_.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from pandas.util._validators import validate_fillna_kwargs
1212

1313
from pandas.core.dtypes.dtypes import ExtensionDtype
14-
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
1514
from pandas.core.dtypes.inference import is_array_like
1615
from pandas.core.dtypes.missing import isna
1716

1817
from pandas import compat
19-
from pandas.core import nanops
18+
from pandas.core import nanops, ops
2019
from pandas.core.algorithms import searchsorted
2120
from pandas.core.array_algos import masked_reductions
2221
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
@@ -436,11 +435,9 @@ def __invert__(self):
436435

437436
@classmethod
438437
def _create_arithmetic_method(cls, op):
438+
@ops.unpack_zerodim_and_defer(op.__name__)
439439
def arithmetic_method(self, other):
440-
if isinstance(other, (ABCIndexClass, ABCSeries)):
441-
return NotImplemented
442-
443-
elif isinstance(other, cls):
440+
if isinstance(other, cls):
444441
other = other._ndarray
445442

446443
with np.errstate(all="ignore"):

pandas/core/arrays/string_.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from pandas.core.dtypes.base import ExtensionDtype, register_extension_dtype
99
from pandas.core.dtypes.common import pandas_dtype
10-
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
1110
from pandas.core.dtypes.inference import is_array_like
1211

1312
from pandas import compat
@@ -312,15 +311,14 @@ def memory_usage(self, deep=False):
312311
@classmethod
313312
def _create_arithmetic_method(cls, op):
314313
# Note: this handles both arithmetic and comparison methods.
314+
315+
@ops.unpack_zerodim_and_defer(op.__name__)
315316
def method(self, other):
316317
from pandas.arrays import BooleanArray
317318

318319
assert op.__name__ in ops.ARITHMETIC_BINOPS | ops.COMPARISON_BINOPS
319320

320-
if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)):
321-
return NotImplemented
322-
323-
elif isinstance(other, cls):
321+
if isinstance(other, cls):
324322
other = other._ndarray
325323

326324
mask = isna(self) | isna(other)

pandas/tests/extension/base/ops.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,13 @@ def test_error(self, data, all_arithmetic_operators):
114114
with pytest.raises(AttributeError):
115115
getattr(data, op_name)
116116

117-
def test_direct_arith_with_series_returns_not_implemented(self, data):
118-
# EAs should return NotImplemented for ops with Series.
117+
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
118+
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
119+
# EAs should return NotImplemented for ops with Series/DataFrame
119120
# Pandas takes care of unboxing the series and calling the EA's op.
120121
other = pd.Series(data)
122+
if box is pd.DataFrame:
123+
other = other.to_frame()
121124
if hasattr(data, "__add__"):
122125
result = data.__add__(other)
123126
assert result is NotImplemented
@@ -156,10 +159,14 @@ def test_compare_array(self, data, all_compare_operators):
156159
other = pd.Series([data[0]] * len(data))
157160
self._compare_other(s, data, op_name, other)
158161

159-
def test_direct_arith_with_series_returns_not_implemented(self, data):
160-
# EAs should return NotImplemented for ops with Series.
162+
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
163+
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
164+
# EAs should return NotImplemented for ops with Series/DataFrame
161165
# Pandas takes care of unboxing the series and calling the EA's op.
162166
other = pd.Series(data)
167+
if box is pd.DataFrame:
168+
other = other.to_frame()
169+
163170
if hasattr(data, "__eq__"):
164171
result = data.__eq__(other)
165172
assert result is NotImplemented

pandas/tests/extension/test_period.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,13 @@ def test_add_series_with_extension_array(self, data):
126126
def test_error(self):
127127
pass
128128

129-
def test_direct_arith_with_series_returns_not_implemented(self, data):
129+
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
130+
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
130131
# Override to use __sub__ instead of __add__
131132
other = pd.Series(data)
133+
if box is pd.DataFrame:
134+
other = other.to_frame()
135+
132136
result = data.__sub__(other)
133137
assert result is NotImplemented
134138

0 commit comments

Comments
 (0)