Skip to content

Commit e510b1a

Browse files
TomAugspurgerjorisvandenbossche
authored andcommitted
BUG: divmod return type (#22932)
1 parent a389a53 commit e510b1a

File tree

6 files changed

+73
-26
lines changed

6 files changed

+73
-26
lines changed

doc/source/extending.rst

+12-3
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,18 @@ your ``MyExtensionArray`` class, as follows:
160160
MyExtensionArray._add_arithmetic_ops()
161161
MyExtensionArray._add_comparison_ops()
162162
163-
Note that since ``pandas`` automatically calls the underlying operator on each
164-
element one-by-one, this might not be as performant as implementing your own
165-
version of the associated operators directly on the ``ExtensionArray``.
163+
164+
.. note::
165+
166+
Since ``pandas`` automatically calls the underlying operator on each
167+
element one-by-one, this might not be as performant as implementing your own
168+
version of the associated operators directly on the ``ExtensionArray``.
169+
170+
For arithmetic operations, this implementation will try to reconstruct a new
171+
``ExtensionArray`` with the result of the element-wise operation. Whether
172+
or not that succeeds depends on whether the operation returns a result
173+
that's valid for the ``ExtensionArray``. If an ``ExtensionArray`` cannot
174+
be reconstructed, an ndarray containing the scalars returned instead.
166175

167176
.. _extending.extension.testing:
168177

pandas/core/arrays/base.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -781,17 +781,24 @@ def convert_values(param):
781781
# a TypeError should be raised
782782
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]
783783

784-
if coerce_to_dtype:
785-
try:
786-
res = self._from_sequence(res)
787-
except Exception:
784+
def _maybe_convert(arr):
785+
if coerce_to_dtype:
788786
# https://github.com/pandas-dev/pandas/issues/22850
789787
# We catch all regular exceptions here, and fall back
790788
# to an ndarray.
791-
res = np.asarray(res)
789+
try:
790+
res = self._from_sequence(arr)
791+
except Exception:
792+
res = np.asarray(arr)
793+
else:
794+
res = np.asarray(arr)
795+
return res
796+
797+
if op.__name__ in {'divmod', 'rdivmod'}:
798+
a, b = zip(*res)
799+
res = _maybe_convert(a), _maybe_convert(b)
792800
else:
793-
res = np.asarray(res)
794-
801+
res = _maybe_convert(res)
795802
return res
796803

797804
op_name = ops._get_op_name(op, True)

pandas/tests/extension/base/ops.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
5858
s = pd.Series(data)
5959
self.check_opname(s, op_name, s.iloc[0], exc=TypeError)
6060

61-
@pytest.mark.xfail(run=False, reason="_reduce needs implementation")
61+
@pytest.mark.xfail(run=False, reason="_reduce needs implementation",
62+
strict=True)
6263
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
6364
# frame & scalar
6465
op_name = all_arithmetic_operators
@@ -77,6 +78,10 @@ def test_divmod(self, data):
7778
self._check_divmod_op(s, divmod, 1, exc=TypeError)
7879
self._check_divmod_op(1, ops.rdivmod, s, exc=TypeError)
7980

81+
def test_divmod_series_array(self, data):
82+
s = pd.Series(data)
83+
self._check_divmod_op(s, divmod, data)
84+
8085
def test_add_series_with_extension_array(self, data):
8186
s = pd.Series(data)
8287
result = s + data

pandas/tests/extension/decimal/test_decimal.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from pandas.tests.extension import base
1010

11-
from .array import DecimalDtype, DecimalArray, make_data
11+
from .array import DecimalDtype, DecimalArray, make_data, to_decimal
1212

1313

1414
@pytest.fixture
@@ -102,7 +102,7 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests):
102102

103103
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
104104

105-
@pytest.mark.xfail(reason="not implemented constructor from dtype")
105+
@pytest.mark.skip(reason="not implemented constructor from dtype")
106106
def test_from_dtype(self, data):
107107
# construct from our dtype & string dtype
108108
pass
@@ -240,9 +240,11 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
240240
context.traps[decimal.DivisionByZero] = divbyzerotrap
241241
context.traps[decimal.InvalidOperation] = invalidoptrap
242242

243-
@pytest.mark.skip(reason="divmod not appropriate for decimal")
244-
def test_divmod(self, data):
245-
pass
243+
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
244+
# We implement divmod
245+
super(TestArithmeticOps, self)._check_divmod_op(
246+
s, op, other, exc=None
247+
)
246248

247249
def test_error(self):
248250
pass
@@ -315,3 +317,21 @@ def test_scalar_ops_from_sequence_raises(class_):
315317
expected = np.array([decimal.Decimal("2.0"), decimal.Decimal("4.0")],
316318
dtype="object")
317319
tm.assert_numpy_array_equal(result, expected)
320+
321+
322+
@pytest.mark.parametrize("reverse, expected_div, expected_mod", [
323+
(False, [0, 1, 1, 2], [1, 0, 1, 0]),
324+
(True, [2, 1, 0, 0], [0, 0, 2, 2]),
325+
])
326+
def test_divmod_array(reverse, expected_div, expected_mod):
327+
# https://github.com/pandas-dev/pandas/issues/22930
328+
arr = to_decimal([1, 2, 3, 4])
329+
if reverse:
330+
div, mod = divmod(2, arr)
331+
else:
332+
div, mod = divmod(arr, 2)
333+
expected_div = to_decimal(expected_div)
334+
expected_mod = to_decimal(expected_mod)
335+
336+
tm.assert_extension_array_equal(div, expected_div)
337+
tm.assert_extension_array_equal(mod, expected_mod)

pandas/tests/extension/json/test_json.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ def test_custom_asserts(self):
131131

132132
class TestConstructors(BaseJSON, base.BaseConstructorsTests):
133133

134-
# TODO: Should this be pytest.mark.skip?
135-
@pytest.mark.xfail(reason="not implemented constructor from dtype")
134+
@pytest.mark.skip(reason="not implemented constructor from dtype")
136135
def test_from_dtype(self, data):
137136
# construct from our dtype & string dtype
138137
pass
@@ -147,13 +146,11 @@ class TestGetitem(BaseJSON, base.BaseGetitemTests):
147146

148147

149148
class TestMissing(BaseJSON, base.BaseMissingTests):
150-
# TODO: Should this be pytest.mark.skip?
151-
@pytest.mark.xfail(reason="Setting a dict as a scalar")
149+
@pytest.mark.skip(reason="Setting a dict as a scalar")
152150
def test_fillna_series(self):
153151
"""We treat dictionaries as a mapping in fillna, not a scalar."""
154152

155-
# TODO: Should this be pytest.mark.skip?
156-
@pytest.mark.xfail(reason="Setting a dict as a scalar")
153+
@pytest.mark.skip(reason="Setting a dict as a scalar")
157154
def test_fillna_frame(self):
158155
"""We treat dictionaries as a mapping in fillna, not a scalar."""
159156

@@ -204,8 +201,7 @@ def test_combine_add(self, data_repeated):
204201

205202

206203
class TestCasting(BaseJSON, base.BaseCastingTests):
207-
# TODO: Should this be pytest.mark.skip?
208-
@pytest.mark.xfail(reason="failing on np.array(self, dtype=str)")
204+
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
209205
def test_astype_str(self):
210206
"""This currently fails in NumPy on np.array(self, dtype=str) with
211207
@@ -257,6 +253,11 @@ def test_add_series_with_extension_array(self, data):
257253
with tm.assert_raises_regex(TypeError, "unsupported"):
258254
ser + data
259255

256+
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
257+
return super(TestArithmeticOps, self)._check_divmod_op(
258+
s, op, other, exc=TypeError
259+
)
260+
260261

261262
class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
262263
pass

pandas/tests/extension/test_categorical.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,11 @@ def test_take_series(self):
140140
def test_reindex_non_na_fill_value(self):
141141
pass
142142

143-
@pytest.mark.xfail(reason="Categorical.take buggy")
143+
@pytest.mark.skip(reason="Categorical.take buggy")
144144
def test_take_empty(self):
145145
pass
146146

147-
@pytest.mark.xfail(reason="test not written correctly for categorical")
147+
@pytest.mark.skip(reason="test not written correctly for categorical")
148148
def test_reindex(self):
149149
pass
150150

@@ -208,6 +208,11 @@ def test_add_series_with_extension_array(self, data):
208208
with tm.assert_raises_regex(TypeError, "cannot perform"):
209209
ser + data
210210

211+
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
212+
return super(TestArithmeticOps, self)._check_divmod_op(
213+
s, op, other, exc=TypeError
214+
)
215+
211216

212217
class TestComparisonOps(base.BaseComparisonOpsTests):
213218

0 commit comments

Comments
 (0)