|
47 | 47 |
|
48 | 48 | import pandas as pd
|
49 | 49 | from pandas._typing import ArrayLike
|
50 |
| -from pandas.core.construction import extract_array |
| 50 | +from pandas.core.construction import array, extract_array |
51 | 51 | from pandas.core.ops import missing
|
52 | 52 | from pandas.core.ops.docstrings import (
|
53 | 53 | _arith_doc_FRAME,
|
@@ -460,6 +460,33 @@ def masked_arith_op(x, y, op):
|
460 | 460 | # Dispatch logic
|
461 | 461 |
|
462 | 462 |
|
def should_extension_dispatch(left: ABCSeries, right: Any) -> bool:
    """
    Decide whether a Series operation should go through
    dispatch_to_extension_op.

    Dispatch is requested when the dtype of ``left`` is an extension dtype,
    datetime64 or timedelta64, or when ``right`` is a non-scalar object that
    is recognised as having an extension-array dtype.

    Parameters
    ----------
    left : Series
    right : object

    Returns
    -------
    bool
    """
    left_dtype = left.dtype
    left_wants_dispatch = (
        is_extension_array_dtype(left_dtype)
        or is_datetime64_dtype(left_dtype)
        or is_timedelta64_dtype(left_dtype)
    )
    if left_wants_dispatch:
        return True

    # GH#22378 disallow scalar to exclude e.g. "category", "Int64"
    return bool(is_extension_array_dtype(right) and not is_scalar(right))
| 488 | + |
| 489 | + |
463 | 490 | def should_series_dispatch(left, right, op):
|
464 | 491 | """
|
465 | 492 | Identify cases where a DataFrame operation should dispatch to its
|
@@ -564,19 +591,18 @@ def dispatch_to_extension_op(op, left, right):
|
564 | 591 | apply the operator defined by op.
|
565 | 592 | """
|
566 | 593 |
|
| 594 | + if left.dtype.kind in "mM": |
| 595 | + # We need to cast datetime64 and timedelta64 ndarrays to |
| 596 | + # DatetimeArray/TimedeltaArray. But we avoid wrapping others in |
| 597 | + # PandasArray as that behaves poorly with e.g. IntegerArray. |
| 598 | + left = array(left) |
| 599 | + |
567 | 600 | # The op calls will raise TypeError if the op is not defined
|
568 | 601 | # on the ExtensionArray
|
569 | 602 |
|
570 | 603 | # unbox Series and Index to arrays
|
571 |
| - if isinstance(left, (ABCSeries, ABCIndexClass)): |
572 |
| - new_left = left._values |
573 |
| - else: |
574 |
| - new_left = left |
575 |
| - |
576 |
| - if isinstance(right, (ABCSeries, ABCIndexClass)): |
577 |
| - new_right = right._values |
578 |
| - else: |
579 |
| - new_right = right |
| 604 | + new_left = extract_array(left, extract_numpy=True) |
| 605 | + new_right = extract_array(right, extract_numpy=True) |
580 | 606 |
|
581 | 607 | try:
|
582 | 608 | res_values = op(new_left, new_right)
|
@@ -684,56 +710,27 @@ def wrapper(left, right):
|
684 | 710 | res_name = get_op_result_name(left, right)
|
685 | 711 | right = maybe_upcast_for_op(right, left.shape)
|
686 | 712 |
|
687 |
| - if is_categorical_dtype(left): |
688 |
| - raise TypeError( |
689 |
| - "{typ} cannot perform the operation " |
690 |
| - "{op}".format(typ=type(left).__name__, op=str_rep) |
691 |
| - ) |
692 |
| - |
693 |
| - elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left): |
694 |
| - from pandas.core.arrays import DatetimeArray |
695 |
| - |
696 |
| - result = dispatch_to_extension_op(op, DatetimeArray(left), right) |
697 |
| - return construct_result(left, result, index=left.index, name=res_name) |
698 |
| - |
699 |
| - elif is_extension_array_dtype(left) or ( |
700 |
| - is_extension_array_dtype(right) and not is_scalar(right) |
701 |
| - ): |
702 |
| - # GH#22378 disallow scalar to exclude e.g. "category", "Int64" |
| 713 | + if should_extension_dispatch(left, right): |
703 | 714 | result = dispatch_to_extension_op(op, left, right)
|
704 |
| - return construct_result(left, result, index=left.index, name=res_name) |
705 | 715 |
|
706 |
| - elif is_timedelta64_dtype(left): |
707 |
| - from pandas.core.arrays import TimedeltaArray |
708 |
| - |
709 |
| - result = dispatch_to_extension_op(op, TimedeltaArray(left), right) |
710 |
| - return construct_result(left, result, index=left.index, name=res_name) |
711 |
| - |
712 |
| - elif is_timedelta64_dtype(right): |
713 |
| - # We should only get here with non-scalar values for right |
714 |
| - # upcast by maybe_upcast_for_op |
| 716 | + elif is_timedelta64_dtype(right) or isinstance( |
| 717 | + right, (ABCDatetimeArray, ABCDatetimeIndex) |
| 718 | + ): |
| 719 | + # We should only get here with td64 right with non-scalar values |
| 720 | + # for right upcast by maybe_upcast_for_op |
715 | 721 | assert not isinstance(right, (np.timedelta64, np.ndarray))
|
716 |
| - |
717 | 722 | result = op(left._values, right)
|
718 | 723 |
|
719 |
| - # We do not pass dtype to ensure that the Series constructor |
720 |
| - # does inference in the case where `result` has object-dtype. |
721 |
| - return construct_result(left, result, index=left.index, name=res_name) |
722 |
| - |
723 |
| - elif isinstance(right, (ABCDatetimeArray, ABCDatetimeIndex)): |
724 |
| - result = op(left._values, right) |
725 |
| - return construct_result(left, result, index=left.index, name=res_name) |
| 724 | + else: |
| 725 | + lvalues = extract_array(left, extract_numpy=True) |
| 726 | + rvalues = extract_array(right, extract_numpy=True) |
726 | 727 |
|
727 |
| - lvalues = left.values |
728 |
| - rvalues = right |
729 |
| - if isinstance(rvalues, (ABCSeries, ABCIndexClass)): |
730 |
| - rvalues = rvalues._values |
| 728 | + with np.errstate(all="ignore"): |
| 729 | + result = na_op(lvalues, rvalues) |
731 | 730 |
|
732 |
| - with np.errstate(all="ignore"): |
733 |
| - result = na_op(lvalues, rvalues) |
734 |
| - return construct_result( |
735 |
| - left, result, index=left.index, name=res_name, dtype=None |
736 |
| - ) |
| 731 | + # We do not pass dtype to ensure that the Series constructor |
| 732 | + # does inference in the case where `result` has object-dtype. |
| 733 | + return construct_result(left, result, index=left.index, name=res_name) |
737 | 734 |
|
738 | 735 | wrapper.__name__ = op_name
|
739 | 736 | return wrapper
|
|
0 commit comments