Skip to content

REF: implement should_extension_dispatch #27815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 14, 2019
103 changes: 50 additions & 53 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

import pandas as pd
from pandas._typing import ArrayLike
from pandas.core.construction import extract_array
from pandas.core.construction import array, extract_array
from pandas.core.ops import missing
from pandas.core.ops.docstrings import (
_arith_doc_FRAME,
Expand Down Expand Up @@ -460,6 +460,33 @@ def masked_arith_op(x, y, op):
# Dispatch logic


def should_extension_dispatch(left: ABCSeries, right: Any) -> bool:
"""
Identify cases where Series operation should use dispatch_to_extension_op.

Parameters
----------
left : Series
right : object

Returns
-------
bool
"""
if (
is_extension_array_dtype(left.dtype)
or is_datetime64_dtype(left.dtype)
or is_timedelta64_dtype(left.dtype)
):
return True

if is_extension_array_dtype(right) and not is_scalar(right):
# GH#22378 disallow scalar to exclude e.g. "category", "Int64"
return True

return False


def should_series_dispatch(left, right, op):
"""
Identify cases where a DataFrame operation should dispatch to its
Expand Down Expand Up @@ -564,19 +591,18 @@ def dispatch_to_extension_op(op, left, right):
apply the operator defined by op.
"""

if left.dtype.kind in "mM":
# We need to cast datetime64 and timedelta64 ndarrays to
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
# PandasArray as that behaves poorly with e.g. IntegerArray.
left = array(left)

# The op calls will raise TypeError if the op is not defined
# on the ExtensionArray

# unbox Series and Index to arrays
if isinstance(left, (ABCSeries, ABCIndexClass)):
new_left = left._values
else:
new_left = left

if isinstance(right, (ABCSeries, ABCIndexClass)):
new_right = right._values
else:
new_right = right
new_left = extract_array(left, extract_numpy=True)
new_right = extract_array(right, extract_numpy=True)

try:
res_values = op(new_left, new_right)
Expand Down Expand Up @@ -684,56 +710,27 @@ def wrapper(left, right):
res_name = get_op_result_name(left, right)
right = maybe_upcast_for_op(right, left.shape)

if is_categorical_dtype(left):
raise TypeError(
"{typ} cannot perform the operation "
"{op}".format(typ=type(left).__name__, op=str_rep)
)

elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
from pandas.core.arrays import DatetimeArray

result = dispatch_to_extension_op(op, DatetimeArray(left), right)
return construct_result(left, result, index=left.index, name=res_name)

elif is_extension_array_dtype(left) or (
is_extension_array_dtype(right) and not is_scalar(right)
):
# GH#22378 disallow scalar to exclude e.g. "category", "Int64"
if should_extension_dispatch(left, right):
result = dispatch_to_extension_op(op, left, right)
return construct_result(left, result, index=left.index, name=res_name)

elif is_timedelta64_dtype(left):
from pandas.core.arrays import TimedeltaArray

result = dispatch_to_extension_op(op, TimedeltaArray(left), right)
return construct_result(left, result, index=left.index, name=res_name)

elif is_timedelta64_dtype(right):
# We should only get here with non-scalar values for right
# upcast by maybe_upcast_for_op
elif is_timedelta64_dtype(right) or isinstance(
right, (ABCDatetimeArray, ABCDatetimeIndex)
):
# We should only get here with td64 right with non-scalar values
# for right upcast by maybe_upcast_for_op
assert not isinstance(right, (np.timedelta64, np.ndarray))

result = op(left._values, right)

# We do not pass dtype to ensure that the Series constructor
# does inference in the case where `result` has object-dtype.
return construct_result(left, result, index=left.index, name=res_name)

elif isinstance(right, (ABCDatetimeArray, ABCDatetimeIndex)):
result = op(left._values, right)
return construct_result(left, result, index=left.index, name=res_name)
else:
lvalues = extract_array(left, extract_numpy=True)
rvalues = extract_array(right, extract_numpy=True)

lvalues = left.values
rvalues = right
if isinstance(rvalues, (ABCSeries, ABCIndexClass)):
rvalues = rvalues._values
with np.errstate(all="ignore"):
result = na_op(lvalues, rvalues)

with np.errstate(all="ignore"):
result = na_op(lvalues, rvalues)
return construct_result(
left, result, index=left.index, name=res_name, dtype=None
)
# We do not pass dtype to ensure that the Series constructor
# does inference in the case where `result` has object-dtype.
return construct_result(left, result, index=left.index, name=res_name)

wrapper.__name__ = op_name
return wrapper
Expand Down
8 changes: 6 additions & 2 deletions pandas/tests/arrays/categorical/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,9 @@ def test_numeric_like_ops(self):
("__mul__", r"\*"),
("__truediv__", "/"),
]:
msg = r"Series cannot perform the operation {}".format(str_rep)
msg = r"Series cannot perform the operation {}|unsupported operand".format(
str_rep
)
with pytest.raises(TypeError, match=msg):
getattr(df, op)(df)

Expand All @@ -375,7 +377,9 @@ def test_numeric_like_ops(self):
("__mul__", r"\*"),
("__truediv__", "/"),
]:
msg = r"Series cannot perform the operation {}".format(str_rep)
msg = r"Series cannot perform the operation {}|unsupported operand".format(
str_rep
)
with pytest.raises(TypeError, match=msg):
getattr(s, op)(2)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):

def test_add_series_with_extension_array(self, data):
ser = pd.Series(data)
with pytest.raises(TypeError, match="cannot perform"):
with pytest.raises(TypeError, match="cannot perform|unsupported operand"):
ser + data

def test_divmod_series_array(self):
Expand Down