Skip to content

Commit 397e36c

Browse files
authored
TYP/REF: use _cmp_method in EAs (#36954)
1 parent 3a83b2c commit 397e36c

File tree

5 files changed

+166
-191
lines changed

5 files changed

+166
-191
lines changed

pandas/core/arrays/boolean.py

+33-41
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pandas.core.dtypes.missing import isna
2424

2525
from pandas.core import ops
26+
from pandas.core.arraylike import OpsMixin
2627

2728
from .masked import BaseMaskedArray, BaseMaskedDtype
2829

@@ -202,7 +203,7 @@ def coerce_to_array(
202203
return values, mask
203204

204205

205-
class BooleanArray(BaseMaskedArray):
206+
class BooleanArray(OpsMixin, BaseMaskedArray):
206207
"""
207208
Array of boolean (True/False) data with missing values.
208209
@@ -603,52 +604,44 @@ def logical_method(self, other):
603604
name = f"__{op.__name__}__"
604605
return set_function_name(logical_method, name, cls)
605606

606-
@classmethod
607-
def _create_comparison_method(cls, op):
608-
@ops.unpack_zerodim_and_defer(op.__name__)
609-
def cmp_method(self, other):
610-
from pandas.arrays import FloatingArray, IntegerArray
607+
def _cmp_method(self, other, op):
608+
from pandas.arrays import FloatingArray, IntegerArray
611609

612-
if isinstance(other, (IntegerArray, FloatingArray)):
613-
return NotImplemented
610+
if isinstance(other, (IntegerArray, FloatingArray)):
611+
return NotImplemented
614612

615-
mask = None
613+
mask = None
616614

617-
if isinstance(other, BooleanArray):
618-
other, mask = other._data, other._mask
615+
if isinstance(other, BooleanArray):
616+
other, mask = other._data, other._mask
619617

620-
elif is_list_like(other):
621-
other = np.asarray(other)
622-
if other.ndim > 1:
623-
raise NotImplementedError(
624-
"can only perform ops with 1-d structures"
625-
)
626-
if len(self) != len(other):
627-
raise ValueError("Lengths must match to compare")
618+
elif is_list_like(other):
619+
other = np.asarray(other)
620+
if other.ndim > 1:
621+
raise NotImplementedError("can only perform ops with 1-d structures")
622+
if len(self) != len(other):
623+
raise ValueError("Lengths must match to compare")
628624

629-
if other is libmissing.NA:
630-
# numpy does not handle pd.NA well as "other" scalar (it returns
631-
# a scalar False instead of an array)
632-
result = np.zeros_like(self._data)
633-
mask = np.ones_like(self._data)
634-
else:
635-
# numpy will show a DeprecationWarning on invalid elementwise
636-
# comparisons, this will raise in the future
637-
with warnings.catch_warnings():
638-
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
639-
with np.errstate(all="ignore"):
640-
result = op(self._data, other)
641-
642-
# nans propagate
643-
if mask is None:
644-
mask = self._mask.copy()
645-
else:
646-
mask = self._mask | mask
625+
if other is libmissing.NA:
626+
# numpy does not handle pd.NA well as "other" scalar (it returns
627+
# a scalar False instead of an array)
628+
result = np.zeros_like(self._data)
629+
mask = np.ones_like(self._data)
630+
else:
631+
# numpy will show a DeprecationWarning on invalid elementwise
632+
# comparisons, this will raise in the future
633+
with warnings.catch_warnings():
634+
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
635+
with np.errstate(all="ignore"):
636+
result = op(self._data, other)
647637

648-
return BooleanArray(result, mask, copy=False)
638+
# nans propagate
639+
if mask is None:
640+
mask = self._mask.copy()
641+
else:
642+
mask = self._mask | mask
649643

650-
name = f"__{op.__name__}"
651-
return set_function_name(cmp_method, name, cls)
644+
return BooleanArray(result, mask, copy=False)
652645

653646
def _reduce(self, name: str, skipna: bool = True, **kwargs):
654647

@@ -741,5 +734,4 @@ def boolean_arithmetic_method(self, other):
741734

742735

743736
BooleanArray._add_logical_ops()
744-
BooleanArray._add_comparison_ops()
745737
BooleanArray._add_arithmetic_ops()

pandas/core/arrays/floating.py

+37-47
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pandas.core.dtypes.missing import isna
2727

2828
from pandas.core import ops
29+
from pandas.core.arraylike import OpsMixin
2930
from pandas.core.ops import invalid_comparison
3031
from pandas.core.ops.common import unpack_zerodim_and_defer
3132
from pandas.core.tools.numeric import to_numeric
@@ -201,7 +202,7 @@ def coerce_to_array(
201202
return values, mask
202203

203204

204-
class FloatingArray(BaseMaskedArray):
205+
class FloatingArray(OpsMixin, BaseMaskedArray):
205206
"""
206207
Array of floating (optional missing) values.
207208
@@ -398,58 +399,48 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
398399
def _values_for_argsort(self) -> np.ndarray:
399400
return self._data
400401

401-
@classmethod
402-
def _create_comparison_method(cls, op):
403-
op_name = op.__name__
402+
def _cmp_method(self, other, op):
403+
from pandas.arrays import BooleanArray, IntegerArray
404404

405-
@unpack_zerodim_and_defer(op.__name__)
406-
def cmp_method(self, other):
407-
from pandas.arrays import BooleanArray, IntegerArray
405+
mask = None
408406

409-
mask = None
407+
if isinstance(other, (BooleanArray, IntegerArray, FloatingArray)):
408+
other, mask = other._data, other._mask
410409

411-
if isinstance(other, (BooleanArray, IntegerArray, FloatingArray)):
412-
other, mask = other._data, other._mask
410+
elif is_list_like(other):
411+
other = np.asarray(other)
412+
if other.ndim > 1:
413+
raise NotImplementedError("can only perform ops with 1-d structures")
413414

414-
elif is_list_like(other):
415-
other = np.asarray(other)
416-
if other.ndim > 1:
417-
raise NotImplementedError(
418-
"can only perform ops with 1-d structures"
419-
)
415+
if other is libmissing.NA:
416+
# numpy does not handle pd.NA well as "other" scalar (it returns
417+
# a scalar False instead of an array)
418+
# This may be fixed by NA.__array_ufunc__. Revisit this check
419+
# once that's implemented.
420+
result = np.zeros(self._data.shape, dtype="bool")
421+
mask = np.ones(self._data.shape, dtype="bool")
422+
else:
423+
with warnings.catch_warnings():
424+
# numpy may show a FutureWarning:
425+
# elementwise comparison failed; returning scalar instead,
426+
# but in the future will perform elementwise comparison
427+
# before returning NotImplemented. We fall back to the correct
428+
# behavior today, so that should be fine to ignore.
429+
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
430+
with np.errstate(all="ignore"):
431+
method = getattr(self._data, f"__{op.__name__}__")
432+
result = method(other)
420433

421-
if other is libmissing.NA:
422-
# numpy does not handle pd.NA well as "other" scalar (it returns
423-
# a scalar False instead of an array)
424-
# This may be fixed by NA.__array_ufunc__. Revisit this check
425-
# once that's implemented.
426-
result = np.zeros(self._data.shape, dtype="bool")
427-
mask = np.ones(self._data.shape, dtype="bool")
428-
else:
429-
with warnings.catch_warnings():
430-
# numpy may show a FutureWarning:
431-
# elementwise comparison failed; returning scalar instead,
432-
# but in the future will perform elementwise comparison
433-
# before returning NotImplemented. We fall back to the correct
434-
# behavior today, so that should be fine to ignore.
435-
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
436-
with np.errstate(all="ignore"):
437-
method = getattr(self._data, f"__{op_name}__")
438-
result = method(other)
439-
440-
if result is NotImplemented:
441-
result = invalid_comparison(self._data, other, op)
442-
443-
# nans propagate
444-
if mask is None:
445-
mask = self._mask.copy()
446-
else:
447-
mask = self._mask | mask
434+
if result is NotImplemented:
435+
result = invalid_comparison(self._data, other, op)
448436

449-
return BooleanArray(result, mask)
437+
# nans propagate
438+
if mask is None:
439+
mask = self._mask.copy()
440+
else:
441+
mask = self._mask | mask
450442

451-
name = f"__{op.__name__}__"
452-
return set_function_name(cmp_method, name, cls)
443+
return BooleanArray(result, mask)
453444

454445
def sum(self, skipna=True, min_count=0, **kwargs):
455446
nv.validate_sum((), kwargs)
@@ -565,7 +556,6 @@ def floating_arithmetic_method(self, other):
565556

566557

567558
FloatingArray._add_arithmetic_ops()
568-
FloatingArray._add_comparison_ops()
569559

570560

571561
_dtype_docstring = """

pandas/core/arrays/integer.py

+40-50
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pandas.core.dtypes.missing import isna
2727

2828
from pandas.core import ops
29+
from pandas.core.arraylike import OpsMixin
2930
from pandas.core.ops import invalid_comparison
3031
from pandas.core.ops.common import unpack_zerodim_and_defer
3132
from pandas.core.tools.numeric import to_numeric
@@ -265,7 +266,7 @@ def coerce_to_array(
265266
return values, mask
266267

267268

268-
class IntegerArray(BaseMaskedArray):
269+
class IntegerArray(OpsMixin, BaseMaskedArray):
269270
"""
270271
Array of integer (optional missing) values.
271272
@@ -493,60 +494,50 @@ def _values_for_argsort(self) -> np.ndarray:
493494
data[self._mask] = data.min() - 1
494495
return data
495496

496-
@classmethod
497-
def _create_comparison_method(cls, op):
498-
op_name = op.__name__
497+
def _cmp_method(self, other, op):
498+
from pandas.core.arrays import BaseMaskedArray, BooleanArray
499499

500-
@unpack_zerodim_and_defer(op.__name__)
501-
def cmp_method(self, other):
502-
from pandas.core.arrays import BaseMaskedArray, BooleanArray
500+
mask = None
503501

504-
mask = None
502+
if isinstance(other, BaseMaskedArray):
503+
other, mask = other._data, other._mask
505504

506-
if isinstance(other, BaseMaskedArray):
507-
other, mask = other._data, other._mask
508-
509-
elif is_list_like(other):
510-
other = np.asarray(other)
511-
if other.ndim > 1:
512-
raise NotImplementedError(
513-
"can only perform ops with 1-d structures"
514-
)
515-
if len(self) != len(other):
516-
raise ValueError("Lengths must match to compare")
505+
elif is_list_like(other):
506+
other = np.asarray(other)
507+
if other.ndim > 1:
508+
raise NotImplementedError("can only perform ops with 1-d structures")
509+
if len(self) != len(other):
510+
raise ValueError("Lengths must match to compare")
511+
512+
if other is libmissing.NA:
513+
# numpy does not handle pd.NA well as "other" scalar (it returns
514+
# a scalar False instead of an array)
515+
# This may be fixed by NA.__array_ufunc__. Revisit this check
516+
# once that's implemented.
517+
result = np.zeros(self._data.shape, dtype="bool")
518+
mask = np.ones(self._data.shape, dtype="bool")
519+
else:
520+
with warnings.catch_warnings():
521+
# numpy may show a FutureWarning:
522+
# elementwise comparison failed; returning scalar instead,
523+
# but in the future will perform elementwise comparison
524+
# before returning NotImplemented. We fall back to the correct
525+
# behavior today, so that should be fine to ignore.
526+
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
527+
with np.errstate(all="ignore"):
528+
method = getattr(self._data, f"__{op.__name__}__")
529+
result = method(other)
517530

518-
if other is libmissing.NA:
519-
# numpy does not handle pd.NA well as "other" scalar (it returns
520-
# a scalar False instead of an array)
521-
# This may be fixed by NA.__array_ufunc__. Revisit this check
522-
# once that's implemented.
523-
result = np.zeros(self._data.shape, dtype="bool")
524-
mask = np.ones(self._data.shape, dtype="bool")
525-
else:
526-
with warnings.catch_warnings():
527-
# numpy may show a FutureWarning:
528-
# elementwise comparison failed; returning scalar instead,
529-
# but in the future will perform elementwise comparison
530-
# before returning NotImplemented. We fall back to the correct
531-
# behavior today, so that should be fine to ignore.
532-
warnings.filterwarnings("ignore", "elementwise", FutureWarning)
533-
with np.errstate(all="ignore"):
534-
method = getattr(self._data, f"__{op_name}__")
535-
result = method(other)
536-
537-
if result is NotImplemented:
538-
result = invalid_comparison(self._data, other, op)
539-
540-
# nans propagate
541-
if mask is None:
542-
mask = self._mask.copy()
543-
else:
544-
mask = self._mask | mask
531+
if result is NotImplemented:
532+
result = invalid_comparison(self._data, other, op)
545533

546-
return BooleanArray(result, mask)
534+
# nans propagate
535+
if mask is None:
536+
mask = self._mask.copy()
537+
else:
538+
mask = self._mask | mask
547539

548-
name = f"__{op.__name__}__"
549-
return set_function_name(cmp_method, name, cls)
540+
return BooleanArray(result, mask)
550541

551542
def sum(self, skipna=True, min_count=0, **kwargs):
552543
nv.validate_sum((), kwargs)
@@ -669,7 +660,6 @@ def integer_arithmetic_method(self, other):
669660

670661

671662
IntegerArray._add_arithmetic_ops()
672-
IntegerArray._add_comparison_ops()
673663

674664

675665
_dtype_docstring = """

0 commit comments

Comments
 (0)