Skip to content

Commit 6a927b0

Browse files
jbrockmendeljreback
authored andcommitted
CLN: match standardized dispatch logic (#27830)
1 parent d187d90 commit 6a927b0

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

pandas/core/arrays/integer.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
is_scalar,
2222
)
2323
from pandas.core.dtypes.dtypes import register_extension_dtype
24-
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
24+
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
2525
from pandas.core.dtypes.missing import isna, notna
2626

2727
from pandas.core import nanops, ops
@@ -592,25 +592,29 @@ def _values_for_argsort(self) -> np.ndarray:
592592

593593
@classmethod
594594
def _create_comparison_method(cls, op):
595-
def cmp_method(self, other):
595+
op_name = op.__name__
596596

597-
op_name = op.__name__
598-
mask = None
597+
def cmp_method(self, other):
599598

600-
if isinstance(other, (ABCSeries, ABCIndexClass)):
599+
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
601600
# Rely on pandas to unbox and dispatch to us.
602601
return NotImplemented
603602

603+
other = lib.item_from_zerodim(other)
604+
mask = None
605+
604606
if isinstance(other, IntegerArray):
605607
other, mask = other._data, other._mask
606608

607609
elif is_list_like(other):
608610
other = np.asarray(other)
609-
if other.ndim > 0 and len(self) != len(other):
611+
if other.ndim > 1:
612+
raise NotImplementedError(
613+
"can only perform ops with 1-d structures"
614+
)
615+
if len(self) != len(other):
610616
raise ValueError("Lengths must match to compare")
611617

612-
other = lib.item_from_zerodim(other)
613-
614618
# numpy will show a DeprecationWarning on invalid elementwise
615619
# comparisons, this will raise in the future
616620
with warnings.catch_warnings():
@@ -683,31 +687,31 @@ def _maybe_mask_result(self, result, mask, other, op_name):
683687

684688
@classmethod
685689
def _create_arithmetic_method(cls, op):
686-
def integer_arithmetic_method(self, other):
690+
op_name = op.__name__
687691

688-
op_name = op.__name__
689-
mask = None
692+
def integer_arithmetic_method(self, other):
690693

691-
if isinstance(other, (ABCSeries, ABCIndexClass)):
694+
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
692695
# Rely on pandas to unbox and dispatch to us.
693696
return NotImplemented
694697

695-
if getattr(other, "ndim", 0) > 1:
696-
raise NotImplementedError("can only perform ops with 1-d structures")
698+
other = lib.item_from_zerodim(other)
699+
mask = None
697700

698701
if isinstance(other, IntegerArray):
699702
other, mask = other._data, other._mask
700703

701-
elif getattr(other, "ndim", None) == 0:
702-
other = other.item()
703-
704704
elif is_list_like(other):
705705
other = np.asarray(other)
706-
if not other.ndim:
707-
other = other.item()
708-
elif other.ndim == 1:
709-
if not (is_float_dtype(other) or is_integer_dtype(other)):
710-
raise TypeError("can only perform ops with numeric values")
706+
if other.ndim > 1:
707+
raise NotImplementedError(
708+
"can only perform ops with 1-d structures"
709+
)
710+
if len(self) != len(other):
711+
raise ValueError("Lengths must match")
712+
if not (is_float_dtype(other) or is_integer_dtype(other)):
713+
raise TypeError("can only perform ops with numeric values")
714+
711715
else:
712716
if not (is_float(other) or is_integer(other)):
713717
raise TypeError("can only perform ops with numeric values")

pandas/tests/arrays/test_integer.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_arith_coerce_scalar(self, data, all_arithmetic_operators):
280280
other = 0.01
281281
self._check_op(s, op, other)
282282

283-
@pytest.mark.parametrize("other", [1.0, 1.0, np.array(1.0), np.array([1.0])])
283+
@pytest.mark.parametrize("other", [1.0, np.array(1.0)])
284284
def test_arithmetic_conversion(self, all_arithmetic_operators, other):
285285
# if we have a float operand we should have a float result
286286
# if that is equal to an integer
@@ -290,6 +290,15 @@ def test_arithmetic_conversion(self, all_arithmetic_operators, other):
290290
result = op(s, other)
291291
assert result.dtype is np.dtype("float")
292292

293+
def test_arith_len_mismatch(self, all_arithmetic_operators):
294+
# operating with a list-like with non-matching length raises
295+
op = self.get_op_from_name(all_arithmetic_operators)
296+
other = np.array([1.0])
297+
298+
s = pd.Series([1, 2, 3], dtype="Int64")
299+
with pytest.raises(ValueError, match="Lengths must match"):
300+
op(s, other)
301+
293302
@pytest.mark.parametrize("other", [0, 0.5])
294303
def test_arith_zero_dim_ndarray(self, other):
295304
arr = integer_array([1, None, 2])
@@ -322,8 +331,9 @@ def test_error(self, data, all_arithmetic_operators):
322331
ops(pd.Series(pd.date_range("20180101", periods=len(s))))
323332

324333
# 2d
325-
with pytest.raises(NotImplementedError):
326-
opa(pd.DataFrame({"A": s}))
334+
result = opa(pd.DataFrame({"A": s}))
335+
assert result is NotImplemented
336+
327337
with pytest.raises(NotImplementedError):
328338
opa(np.arange(len(s)).reshape(-1, len(s)))
329339

0 commit comments

Comments
 (0)