diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index dcd31abaa8857..0e9a51cf91d0f 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -406,6 +406,8 @@ Numeric - Bug in :meth:`DataFrame.rank` raising ``ValueError`` with ``object`` columns and ``method="first"`` (:issue:`41931`) - Bug in :meth:`DataFrame.rank` treating missing values and extreme values as equal (for example ``np.nan`` and ``np.inf``), causing incorrect results when ``na_option="bottom"`` or ``na_option="top`` used (:issue:`41931`) - Bug in ``numexpr`` engine still being used when the option ``compute.use_numexpr`` is set to ``False`` (:issue:`32556`) +- Bug in :class:`DataFrame` arithmetic ops with a subclass whose :meth:`_constructor` attribute is a callable other than the subclass itself (:issue:`43201`) +- Conversion ^^^^^^^^^^ diff --git a/pandas/core/frame.py b/pandas/core/frame.py index aad7213c93a1d..4d4fe3432d7c9 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -6954,7 +6954,7 @@ def _dispatch_frame_op(self, right, func: Callable, axis: int | None = None): # i.e. scalar, faster than checking np.ndim(right) == 0 with np.errstate(all="ignore"): bm = self._mgr.apply(array_op, right=right) - return type(self)(bm) + return self._constructor(bm) elif isinstance(right, DataFrame): assert self.index.equals(right.index) @@ -6975,7 +6975,7 @@ def _dispatch_frame_op(self, right, func: Callable, axis: int | None = None): right._mgr, # type: ignore[arg-type] array_op, ) - return type(self)(bm) + return self._constructor(bm) elif isinstance(right, Series) and axis == 1: # axis=1 means we want to operate row-by-row diff --git a/pandas/tests/frame/test_arithmetic.py b/pandas/tests/frame/test_arithmetic.py index afa9593807acc..1ddb18c218cc6 100644 --- a/pandas/tests/frame/test_arithmetic.py +++ b/pandas/tests/frame/test_arithmetic.py @@ -1,5 +1,6 @@ from collections import deque from datetime import datetime +import functools import operator import re @@ -1845,3 +1846,39 @@ def test_bool_frame_mult_float(): result = df * 1.0 expected = DataFrame(np.ones((2, 2)), list("ab"), list("cd")) tm.assert_frame_equal(result, expected) + + +def test_frame_op_subclass_nonclass_constructor(): + # GH#43201 subclass._constructor is a function, not the subclass itself + + class SubclassedSeries(Series): + @property + def _constructor(self): + return SubclassedSeries + + @property + def _constructor_expanddim(self): + return SubclassedDataFrame + + class SubclassedDataFrame(DataFrame): + _metadata = ["my_extra_data"] + + def __init__(self, my_extra_data, *args, **kwargs): + self.my_extra_data = my_extra_data + super().__init__(*args, **kwargs) + + @property + def _constructor(self): + return functools.partial(type(self), self.my_extra_data) + + @property + def _constructor_sliced(self): + return SubclassedSeries + + sdf = SubclassedDataFrame("some_data", {"A": [1, 2, 3], "B": [4, 5, 6]}) + result = sdf * 2 + expected = SubclassedDataFrame("some_data", {"A": [2, 4, 6], "B": [8, 10, 12]}) + tm.assert_frame_equal(result, expected) + + result = sdf + sdf + tm.assert_frame_equal(result, expected)