diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 9422434a1d998..ab940384594bc 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -160,9 +160,18 @@ your ``MyExtensionArray`` class, as follows: MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() -Note that since ``pandas`` automatically calls the underlying operator on each -element one-by-one, this might not be as performant as implementing your own -version of the associated operators directly on the ``ExtensionArray``. + +.. note:: + + Since ``pandas`` automatically calls the underlying operator on each + element one-by-one, this might not be as performant as implementing your own + version of the associated operators directly on the ``ExtensionArray``. + +For arithmetic operations, this implementation will try to reconstruct a new +``ExtensionArray`` with the result of the element-wise operation. Whether +or not that succeeds depends on whether the operation returns a result +that's valid for the ``ExtensionArray``. If an ``ExtensionArray`` cannot +be reconstructed, an ndarray containing the scalars returned instead. .. _extending.extension.testing: diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index f7c4ee35adfe4..efe587c6aaaad 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -781,17 +781,24 @@ def convert_values(param): # a TypeError should be raised res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] - if coerce_to_dtype: - try: - res = self._from_sequence(res) - except Exception: + def _maybe_convert(arr): + if coerce_to_dtype: # https://github.com/pandas-dev/pandas/issues/22850 # We catch all regular exceptions here, and fall back # to an ndarray. - res = np.asarray(res) + try: + res = self._from_sequence(arr) + except Exception: + res = np.asarray(arr) + else: + res = np.asarray(arr) + return res + + if op.__name__ in {'divmod', 'rdivmod'}: + a, b = zip(*res) + res = _maybe_convert(a), _maybe_convert(b) else: - res = np.asarray(res) - + res = _maybe_convert(res) return res op_name = ops._get_op_name(op, True) diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index ee4a92146128b..36696bc292162 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -58,7 +58,8 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators): s = pd.Series(data) self.check_opname(s, op_name, s.iloc[0], exc=TypeError) - @pytest.mark.xfail(run=False, reason="_reduce needs implementation") + @pytest.mark.xfail(run=False, reason="_reduce needs implementation", + strict=True) def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): # frame & scalar op_name = all_arithmetic_operators @@ -77,6 +78,10 @@ def test_divmod(self, data): self._check_divmod_op(s, divmod, 1, exc=TypeError) self._check_divmod_op(1, ops.rdivmod, s, exc=TypeError) + def test_divmod_series_array(self, data): + s = pd.Series(data) + self._check_divmod_op(s, divmod, data) + def test_add_series_with_extension_array(self, data): s = pd.Series(data) result = s + data diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index dd625d6e1eb3c..6488c7724229b 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -8,7 +8,7 @@ from pandas.tests.extension import base -from .array import DecimalDtype, DecimalArray, make_data +from .array import DecimalDtype, DecimalArray, make_data, to_decimal @pytest.fixture @@ -102,7 +102,7 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests): class TestConstructors(BaseDecimal, base.BaseConstructorsTests): - @pytest.mark.xfail(reason="not implemented constructor from dtype") + @pytest.mark.skip(reason="not implemented constructor from dtype") def test_from_dtype(self, data): # construct from our dtype & string dtype pass @@ -240,9 +240,11 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators): context.traps[decimal.DivisionByZero] = divbyzerotrap context.traps[decimal.InvalidOperation] = invalidoptrap - @pytest.mark.skip(reason="divmod not appropriate for decimal") - def test_divmod(self, data): - pass + def _check_divmod_op(self, s, op, other, exc=NotImplementedError): + # We implement divmod + super(TestArithmeticOps, self)._check_divmod_op( + s, op, other, exc=None + ) def test_error(self): pass @@ -315,3 +317,21 @@ def test_scalar_ops_from_sequence_raises(class_): expected = np.array([decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object") tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("reverse, expected_div, expected_mod", [ + (False, [0, 1, 1, 2], [1, 0, 1, 0]), + (True, [2, 1, 0, 0], [0, 0, 2, 2]), +]) +def test_divmod_array(reverse, expected_div, expected_mod): + # https://github.com/pandas-dev/pandas/issues/22930 + arr = to_decimal([1, 2, 3, 4]) + if reverse: + div, mod = divmod(2, arr) + else: + div, mod = divmod(arr, 2) + expected_div = to_decimal(expected_div) + expected_mod = to_decimal(expected_mod) + + tm.assert_extension_array_equal(div, expected_div) + tm.assert_extension_array_equal(mod, expected_mod) diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index bcbc3e9109182..115afdcc99f2b 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -131,8 +131,7 @@ def test_custom_asserts(self): class TestConstructors(BaseJSON, base.BaseConstructorsTests): - # TODO: Should this be pytest.mark.skip? - @pytest.mark.xfail(reason="not implemented constructor from dtype") + @pytest.mark.skip(reason="not implemented constructor from dtype") def test_from_dtype(self, data): # construct from our dtype & string dtype pass @@ -147,13 +146,11 @@ class TestGetitem(BaseJSON, base.BaseGetitemTests): class TestMissing(BaseJSON, base.BaseMissingTests): - # TODO: Should this be pytest.mark.skip? - @pytest.mark.xfail(reason="Setting a dict as a scalar") + @pytest.mark.skip(reason="Setting a dict as a scalar") def test_fillna_series(self): """We treat dictionaries as a mapping in fillna, not a scalar.""" - # TODO: Should this be pytest.mark.skip? - @pytest.mark.xfail(reason="Setting a dict as a scalar") + @pytest.mark.skip(reason="Setting a dict as a scalar") def test_fillna_frame(self): """We treat dictionaries as a mapping in fillna, not a scalar.""" @@ -204,8 +201,7 @@ def test_combine_add(self, data_repeated): class TestCasting(BaseJSON, base.BaseCastingTests): - # TODO: Should this be pytest.mark.skip? - @pytest.mark.xfail(reason="failing on np.array(self, dtype=str)") + @pytest.mark.skip(reason="failing on np.array(self, dtype=str)") def test_astype_str(self): """This currently fails in NumPy on np.array(self, dtype=str) with @@ -257,6 +253,11 @@ def test_add_series_with_extension_array(self, data): with tm.assert_raises_regex(TypeError, "unsupported"): ser + data + def _check_divmod_op(self, s, op, other, exc=NotImplementedError): + return super(TestArithmeticOps, self)._check_divmod_op( + s, op, other, exc=TypeError + ) + class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): pass diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index c588552572aed..f118279c4b915 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -140,11 +140,11 @@ def test_take_series(self): def test_reindex_non_na_fill_value(self): pass - @pytest.mark.xfail(reason="Categorical.take buggy") + @pytest.mark.skip(reason="Categorical.take buggy") def test_take_empty(self): pass - @pytest.mark.xfail(reason="test not written correctly for categorical") + @pytest.mark.skip(reason="test not written correctly for categorical") def test_reindex(self): pass @@ -208,6 +208,11 @@ def test_add_series_with_extension_array(self, data): with tm.assert_raises_regex(TypeError, "cannot perform"): ser + data + def _check_divmod_op(self, s, op, other, exc=NotImplementedError): + return super(TestArithmeticOps, self)._check_divmod_op( + s, op, other, exc=TypeError + ) + class TestComparisonOps(base.BaseComparisonOpsTests):