Skip to content

Commit ee65738

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
TYP: use OpsMixin for DecimalArray (pandas-dev#36961)
1 parent 76dc275 commit ee65738

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

pandas/tests/extension/decimal/array.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import numpy as np
88

99
from pandas.core.dtypes.base import ExtensionDtype
10-
from pandas.core.dtypes.common import is_dtype_equal, pandas_dtype
10+
from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype
1111

1212
import pandas as pd
1313
from pandas.api.extensions import no_default, register_extension_dtype
14+
from pandas.core.arraylike import OpsMixin
1415
from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin
1516
from pandas.core.indexers import check_array_indexer
1617

@@ -44,7 +45,7 @@ def _is_numeric(self) -> bool:
4445
return True
4546

4647

47-
class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
48+
class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
4849
__array_priority__ = 1000
4950

5051
def __init__(self, values, dtype=None, copy=False, context=None):
@@ -197,6 +198,25 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
197198
) from err
198199
return op(axis=0)
199200

201+
def _cmp_method(self, other, op):
202+
# For use with OpsMixin
203+
def convert_values(param):
204+
if isinstance(param, ExtensionArray) or is_list_like(param):
205+
ovalues = param
206+
else:
207+
# Assume it's an object
208+
ovalues = [param] * len(self)
209+
return ovalues
210+
211+
lvalues = self
212+
rvalues = convert_values(other)
213+
214+
# If the operator is not defined for the underlying objects,
215+
# a TypeError should be raised
216+
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]
217+
218+
return np.asarray(res, dtype=bool)
219+
200220

201221
def to_decimal(values, context=None):
202222
return DecimalArray([decimal.Decimal(x) for x in values], context=context)
@@ -207,4 +227,3 @@ def make_data():
207227

208228

209229
DecimalArray._add_arithmetic_ops()
210-
DecimalArray._add_comparison_ops()

0 commit comments

Comments
 (0)