From 2ecbfd7eaf18e06cacb27f3cb4960f1094f8fc7b Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 7 Oct 2020 15:49:55 -0700 Subject: [PATCH] TYP: use OpsMixin for DecimalArray --- pandas/tests/extension/decimal/array.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 2895f33d5c887..3d1ebb01d632f 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -7,10 +7,11 @@ import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.dtypes.common import is_dtype_equal, pandas_dtype +from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype import pandas as pd from pandas.api.extensions import no_default, register_extension_dtype +from pandas.core.arraylike import OpsMixin from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin from pandas.core.indexers import check_array_indexer @@ -44,7 +45,7 @@ def _is_numeric(self) -> bool: return True -class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): +class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): __array_priority__ = 1000 def __init__(self, values, dtype=None, copy=False, context=None): @@ -197,6 +198,25 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs): ) from err return op(axis=0) + def _cmp_method(self, other, op): + # For use with OpsMixin + def convert_values(param): + if isinstance(param, ExtensionArray) or is_list_like(param): + ovalues = param + else: + # Assume it's an object + ovalues = [param] * len(self) + return ovalues + + lvalues = self + rvalues = convert_values(other) + + # If the operator is not defined for the underlying objects, + # a TypeError should be raised + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] + + return np.asarray(res, dtype=bool) + def to_decimal(values, context=None): return DecimalArray([decimal.Decimal(x) for x in values], context=context) @@ -207,4 +227,3 @@ def make_data(): DecimalArray._add_arithmetic_ops() -DecimalArray._add_comparison_ops()