7
7
import numpy as np
8
8
9
9
from pandas .core .dtypes .base import ExtensionDtype
10
- from pandas .core .dtypes .common import is_dtype_equal , pandas_dtype
10
+ from pandas .core .dtypes .common import is_dtype_equal , is_list_like , pandas_dtype
11
11
12
12
import pandas as pd
13
13
from pandas .api .extensions import no_default , register_extension_dtype
14
+ from pandas .core .arraylike import OpsMixin
14
15
from pandas .core .arrays import ExtensionArray , ExtensionScalarOpsMixin
15
16
from pandas .core .indexers import check_array_indexer
16
17
@@ -44,7 +45,7 @@ def _is_numeric(self) -> bool:
44
45
return True
45
46
46
47
47
- class DecimalArray (ExtensionArray , ExtensionScalarOpsMixin ):
48
+ class DecimalArray (OpsMixin , ExtensionScalarOpsMixin , ExtensionArray ):
48
49
__array_priority__ = 1000
49
50
50
51
def __init__ (self , values , dtype = None , copy = False , context = None ):
@@ -197,6 +198,25 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
197
198
) from err
198
199
return op (axis = 0 )
199
200
201
+ def _cmp_method (self , other , op ):
202
+ # For use with OpsMixin
203
+ def convert_values (param ):
204
+ if isinstance (param , ExtensionArray ) or is_list_like (param ):
205
+ ovalues = param
206
+ else :
207
+ # Assume it's an object
208
+ ovalues = [param ] * len (self )
209
+ return ovalues
210
+
211
+ lvalues = self
212
+ rvalues = convert_values (other )
213
+
214
+ # If the operator is not defined for the underlying objects,
215
+ # a TypeError should be raised
216
+ res = [op (a , b ) for (a , b ) in zip (lvalues , rvalues )]
217
+
218
+ return np .asarray (res , dtype = bool )
219
+
200
220
201
221
def to_decimal (values , context = None ):
202
222
return DecimalArray ([decimal .Decimal (x ) for x in values ], context = context )
@@ -207,4 +227,3 @@ def make_data():
207
227
208
228
209
229
DecimalArray ._add_arithmetic_ops ()
210
- DecimalArray ._add_comparison_ops ()
0 commit comments