Skip to content

Commit ae319d2

Browse files
jbrockmendelquintusdias
authored andcommitted
REF: implement should_extension_dispatch (pandas-dev#27815)
1 parent ba511f1 commit ae319d2

File tree

3 files changed

+57
-56
lines changed

3 files changed

+57
-56
lines changed

pandas/core/ops/__init__.py

+50-53
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
import pandas as pd
4949
from pandas._typing import ArrayLike
50-
from pandas.core.construction import extract_array
50+
from pandas.core.construction import array, extract_array
5151
from pandas.core.ops import missing
5252
from pandas.core.ops.docstrings import (
5353
_arith_doc_FRAME,
@@ -460,6 +460,33 @@ def masked_arith_op(x, y, op):
460460
# Dispatch logic
461461

462462

463+
def should_extension_dispatch(left: ABCSeries, right: Any) -> bool:
464+
"""
465+
Identify cases where Series operation should use dispatch_to_extension_op.
466+
467+
Parameters
468+
----------
469+
left : Series
470+
right : object
471+
472+
Returns
473+
-------
474+
bool
475+
"""
476+
if (
477+
is_extension_array_dtype(left.dtype)
478+
or is_datetime64_dtype(left.dtype)
479+
or is_timedelta64_dtype(left.dtype)
480+
):
481+
return True
482+
483+
if is_extension_array_dtype(right) and not is_scalar(right):
484+
# GH#22378 disallow scalar to exclude e.g. "category", "Int64"
485+
return True
486+
487+
return False
488+
489+
463490
def should_series_dispatch(left, right, op):
464491
"""
465492
Identify cases where a DataFrame operation should dispatch to its
@@ -564,19 +591,18 @@ def dispatch_to_extension_op(op, left, right):
564591
apply the operator defined by op.
565592
"""
566593

594+
if left.dtype.kind in "mM":
595+
# We need to cast datetime64 and timedelta64 ndarrays to
596+
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
597+
# PandasArray as that behaves poorly with e.g. IntegerArray.
598+
left = array(left)
599+
567600
# The op calls will raise TypeError if the op is not defined
568601
# on the ExtensionArray
569602

570603
# unbox Series and Index to arrays
571-
if isinstance(left, (ABCSeries, ABCIndexClass)):
572-
new_left = left._values
573-
else:
574-
new_left = left
575-
576-
if isinstance(right, (ABCSeries, ABCIndexClass)):
577-
new_right = right._values
578-
else:
579-
new_right = right
604+
new_left = extract_array(left, extract_numpy=True)
605+
new_right = extract_array(right, extract_numpy=True)
580606

581607
try:
582608
res_values = op(new_left, new_right)
@@ -684,56 +710,27 @@ def wrapper(left, right):
684710
res_name = get_op_result_name(left, right)
685711
right = maybe_upcast_for_op(right, left.shape)
686712

687-
if is_categorical_dtype(left):
688-
raise TypeError(
689-
"{typ} cannot perform the operation "
690-
"{op}".format(typ=type(left).__name__, op=str_rep)
691-
)
692-
693-
elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
694-
from pandas.core.arrays import DatetimeArray
695-
696-
result = dispatch_to_extension_op(op, DatetimeArray(left), right)
697-
return construct_result(left, result, index=left.index, name=res_name)
698-
699-
elif is_extension_array_dtype(left) or (
700-
is_extension_array_dtype(right) and not is_scalar(right)
701-
):
702-
# GH#22378 disallow scalar to exclude e.g. "category", "Int64"
713+
if should_extension_dispatch(left, right):
703714
result = dispatch_to_extension_op(op, left, right)
704-
return construct_result(left, result, index=left.index, name=res_name)
705715

706-
elif is_timedelta64_dtype(left):
707-
from pandas.core.arrays import TimedeltaArray
708-
709-
result = dispatch_to_extension_op(op, TimedeltaArray(left), right)
710-
return construct_result(left, result, index=left.index, name=res_name)
711-
712-
elif is_timedelta64_dtype(right):
713-
# We should only get here with non-scalar values for right
714-
# upcast by maybe_upcast_for_op
716+
elif is_timedelta64_dtype(right) or isinstance(
717+
right, (ABCDatetimeArray, ABCDatetimeIndex)
718+
):
719+
# We should only get here with td64 right with non-scalar values
720+
# for right upcast by maybe_upcast_for_op
715721
assert not isinstance(right, (np.timedelta64, np.ndarray))
716-
717722
result = op(left._values, right)
718723

719-
# We do not pass dtype to ensure that the Series constructor
720-
# does inference in the case where `result` has object-dtype.
721-
return construct_result(left, result, index=left.index, name=res_name)
722-
723-
elif isinstance(right, (ABCDatetimeArray, ABCDatetimeIndex)):
724-
result = op(left._values, right)
725-
return construct_result(left, result, index=left.index, name=res_name)
724+
else:
725+
lvalues = extract_array(left, extract_numpy=True)
726+
rvalues = extract_array(right, extract_numpy=True)
726727

727-
lvalues = left.values
728-
rvalues = right
729-
if isinstance(rvalues, (ABCSeries, ABCIndexClass)):
730-
rvalues = rvalues._values
728+
with np.errstate(all="ignore"):
729+
result = na_op(lvalues, rvalues)
731730

732-
with np.errstate(all="ignore"):
733-
result = na_op(lvalues, rvalues)
734-
return construct_result(
735-
left, result, index=left.index, name=res_name, dtype=None
736-
)
731+
# We do not pass dtype to ensure that the Series constructor
732+
# does inference in the case where `result` has object-dtype.
733+
return construct_result(left, result, index=left.index, name=res_name)
737734

738735
wrapper.__name__ = op_name
739736
return wrapper

pandas/tests/arrays/categorical/test_operators.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ def test_numeric_like_ops(self):
349349
("__mul__", r"\*"),
350350
("__truediv__", "/"),
351351
]:
352-
msg = r"Series cannot perform the operation {}".format(str_rep)
352+
msg = r"Series cannot perform the operation {}|unsupported operand".format(
353+
str_rep
354+
)
353355
with pytest.raises(TypeError, match=msg):
354356
getattr(df, op)(df)
355357

@@ -375,7 +377,9 @@ def test_numeric_like_ops(self):
375377
("__mul__", r"\*"),
376378
("__truediv__", "/"),
377379
]:
378-
msg = r"Series cannot perform the operation {}".format(str_rep)
380+
msg = r"Series cannot perform the operation {}|unsupported operand".format(
381+
str_rep
382+
)
379383
with pytest.raises(TypeError, match=msg):
380384
getattr(s, op)(2)
381385

pandas/tests/extension/test_categorical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
211211

212212
def test_add_series_with_extension_array(self, data):
213213
ser = pd.Series(data)
214-
with pytest.raises(TypeError, match="cannot perform"):
214+
with pytest.raises(TypeError, match="cannot perform|unsupported operand"):
215215
ser + data
216216

217217
def test_divmod_series_array(self):

0 commit comments

Comments
 (0)