diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index 7229b3de4e9f0..843f12c20b07b 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -47,7 +47,7 @@ import pandas as pd from pandas._typing import ArrayLike -from pandas.core.construction import extract_array +from pandas.core.construction import array, extract_array from pandas.core.ops import missing from pandas.core.ops.docstrings import ( _arith_doc_FRAME, @@ -460,6 +460,33 @@ def masked_arith_op(x, y, op): # Dispatch logic +def should_extension_dispatch(left: ABCSeries, right: Any) -> bool: + """ + Identify cases where Series operation should use dispatch_to_extension_op. + + Parameters + ---------- + left : Series + right : object + + Returns + ------- + bool + """ + if ( + is_extension_array_dtype(left.dtype) + or is_datetime64_dtype(left.dtype) + or is_timedelta64_dtype(left.dtype) + ): + return True + + if is_extension_array_dtype(right) and not is_scalar(right): + # GH#22378 disallow scalar to exclude e.g. "category", "Int64" + return True + + return False + + def should_series_dispatch(left, right, op): """ Identify cases where a DataFrame operation should dispatch to its @@ -564,19 +591,18 @@ def dispatch_to_extension_op(op, left, right): apply the operator defined by op. """ + if left.dtype.kind in "mM": + # We need to cast datetime64 and timedelta64 ndarrays to + # DatetimeArray/TimedeltaArray. But we avoid wrapping others in + # PandasArray as that behaves poorly with e.g. IntegerArray. + left = array(left) + # The op calls will raise TypeError if the op is not defined # on the ExtensionArray # unbox Series and Index to arrays - if isinstance(left, (ABCSeries, ABCIndexClass)): - new_left = left._values - else: - new_left = left - - if isinstance(right, (ABCSeries, ABCIndexClass)): - new_right = right._values - else: - new_right = right + new_left = extract_array(left, extract_numpy=True) + new_right = extract_array(right, extract_numpy=True) try: res_values = op(new_left, new_right) @@ -684,56 +710,27 @@ def wrapper(left, right): res_name = get_op_result_name(left, right) right = maybe_upcast_for_op(right, left.shape) - if is_categorical_dtype(left): - raise TypeError( - "{typ} cannot perform the operation " - "{op}".format(typ=type(left).__name__, op=str_rep) - ) - - elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left): - from pandas.core.arrays import DatetimeArray - - result = dispatch_to_extension_op(op, DatetimeArray(left), right) - return construct_result(left, result, index=left.index, name=res_name) - - elif is_extension_array_dtype(left) or ( - is_extension_array_dtype(right) and not is_scalar(right) - ): - # GH#22378 disallow scalar to exclude e.g. "category", "Int64" + if should_extension_dispatch(left, right): result = dispatch_to_extension_op(op, left, right) - return construct_result(left, result, index=left.index, name=res_name) - elif is_timedelta64_dtype(left): - from pandas.core.arrays import TimedeltaArray - - result = dispatch_to_extension_op(op, TimedeltaArray(left), right) - return construct_result(left, result, index=left.index, name=res_name) - - elif is_timedelta64_dtype(right): - # We should only get here with non-scalar values for right - # upcast by maybe_upcast_for_op + elif is_timedelta64_dtype(right) or isinstance( + right, (ABCDatetimeArray, ABCDatetimeIndex) + ): + # We should only get here with td64 right with non-scalar values + # for right upcast by maybe_upcast_for_op assert not isinstance(right, (np.timedelta64, np.ndarray)) - result = op(left._values, right) - # We do not pass dtype to ensure that the Series constructor - # does inference in the case where `result` has object-dtype. - return construct_result(left, result, index=left.index, name=res_name) - - elif isinstance(right, (ABCDatetimeArray, ABCDatetimeIndex)): - result = op(left._values, right) - return construct_result(left, result, index=left.index, name=res_name) + else: + lvalues = extract_array(left, extract_numpy=True) + rvalues = extract_array(right, extract_numpy=True) - lvalues = left.values - rvalues = right - if isinstance(rvalues, (ABCSeries, ABCIndexClass)): - rvalues = rvalues._values + with np.errstate(all="ignore"): + result = na_op(lvalues, rvalues) - with np.errstate(all="ignore"): - result = na_op(lvalues, rvalues) - return construct_result( - left, result, index=left.index, name=res_name, dtype=None - ) + # We do not pass dtype to ensure that the Series constructor + # does inference in the case where `result` has object-dtype. + return construct_result(left, result, index=left.index, name=res_name) wrapper.__name__ = op_name return wrapper diff --git a/pandas/tests/arrays/categorical/test_operators.py b/pandas/tests/arrays/categorical/test_operators.py index 9a09ea8422b1f..22c1d5373372a 100644 --- a/pandas/tests/arrays/categorical/test_operators.py +++ b/pandas/tests/arrays/categorical/test_operators.py @@ -349,7 +349,9 @@ def test_numeric_like_ops(self): ("__mul__", r"\*"), ("__truediv__", "/"), ]: - msg = r"Series cannot perform the operation {}".format(str_rep) + msg = r"Series cannot perform the operation {}|unsupported operand".format( + str_rep + ) with pytest.raises(TypeError, match=msg): getattr(df, op)(df) @@ -375,7 +377,9 @@ def test_numeric_like_ops(self): ("__mul__", r"\*"), ("__truediv__", "/"), ]: - msg = r"Series cannot perform the operation {}".format(str_rep) + msg = r"Series cannot perform the operation {}|unsupported operand".format( + str_rep + ) with pytest.raises(TypeError, match=msg): getattr(s, op)(2) diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index f7456d24ad6d3..0c0e8b0123c03 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -211,7 +211,7 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators): def test_add_series_with_extension_array(self, data): ser = pd.Series(data) - with pytest.raises(TypeError, match="cannot perform"): + with pytest.raises(TypeError, match="cannot perform|unsupported operand"): ser + data def test_divmod_series_array(self):