Skip to content

Commit 2cd7888

Browse files
jbrockmendeljreback
authored andcommitted
REF: do extract_array earlier in series arith/comparison ops (#28066)
1 parent 2aeed3f commit 2cd7888

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

pandas/core/ops/__init__.py

+63-37
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import datetime
77
import operator
8-
from typing import Any, Callable, Tuple
8+
from typing import Any, Callable, Tuple, Union
99

1010
import numpy as np
1111

@@ -34,10 +34,11 @@
3434
ABCIndexClass,
3535
ABCSeries,
3636
ABCSparseSeries,
37+
ABCTimedeltaArray,
38+
ABCTimedeltaIndex,
3739
)
3840
from pandas.core.dtypes.missing import isna, notna
3941

40-
import pandas as pd
4142
from pandas._typing import ArrayLike
4243
from pandas.core.construction import array, extract_array
4344
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY, define_na_arithmetic_op
@@ -148,6 +149,8 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
148149
Be careful to call this *after* determining the `name` attribute to be
149150
attached to the result of the arithmetic operation.
150151
"""
152+
from pandas.core.arrays import TimedeltaArray
153+
151154
if type(obj) is datetime.timedelta:
152155
# GH#22390 cast up to Timedelta to rely on Timedelta
153156
# implementation; otherwise operation against numeric-dtype
@@ -157,12 +160,10 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
157160
if isna(obj):
158161
# wrapping timedelta64("NaT") in Timedelta returns NaT,
159162
# which would incorrectly be treated as a datetime-NaT, so
160-
# we broadcast and wrap in a Series
163+
# we broadcast and wrap in a TimedeltaArray
164+
obj = obj.astype("timedelta64[ns]")
161165
right = np.broadcast_to(obj, shape)
162-
163-
# Note: we use Series instead of TimedeltaIndex to avoid having
164-
# to worry about catching NullFrequencyError.
165-
return pd.Series(right)
166+
return TimedeltaArray(right)
166167

167168
# In particular non-nanosecond timedelta64 needs to be cast to
168169
# nanoseconds, or else we get undesired behavior like
@@ -173,7 +174,7 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
173174
# GH#22390 Unfortunately we need to special-case right-hand
174175
# timedelta64 dtypes because numpy casts integer dtypes to
175176
# timedelta64 when operating with timedelta64
176-
return pd.TimedeltaIndex(obj)
177+
return TimedeltaArray._from_sequence(obj)
177178
return obj
178179

179180

@@ -520,13 +521,34 @@ def column_op(a, b):
520521
return result
521522

522523

523-
def dispatch_to_extension_op(op, left, right):
524+
def dispatch_to_extension_op(
525+
op,
526+
left: Union[ABCExtensionArray, np.ndarray],
527+
right: Any,
528+
keep_null_freq: bool = False,
529+
):
524530
"""
525531
Assume that left or right is a Series backed by an ExtensionArray,
526532
apply the operator defined by op.
533+
534+
Parameters
535+
----------
536+
op : binary operator
537+
left : ExtensionArray or np.ndarray
538+
right : object
539+
keep_null_freq : bool, default False
540+
Whether to re-raise a NullFrequencyError unchanged, as opposed to
541+
catching and raising TypeError.
542+
543+
Returns
544+
-------
545+
ExtensionArray or np.ndarray
546+
2-tuple of these if op is divmod or rdivmod
527547
"""
548+
# NB: left and right should already be unboxed, so neither should be
549+
# a Series or Index.
528550

529-
if left.dtype.kind in "mM":
551+
if left.dtype.kind in "mM" and isinstance(left, np.ndarray):
530552
# We need to cast datetime64 and timedelta64 ndarrays to
531553
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
532554
# PandasArray as that behaves poorly with e.g. IntegerArray.
@@ -535,15 +557,15 @@ def dispatch_to_extension_op(op, left, right):
535557
# The op calls will raise TypeError if the op is not defined
536558
# on the ExtensionArray
537559

538-
# unbox Series and Index to arrays
539-
new_left = extract_array(left, extract_numpy=True)
540-
new_right = extract_array(right, extract_numpy=True)
541-
542560
try:
543-
res_values = op(new_left, new_right)
561+
res_values = op(left, right)
544562
except NullFrequencyError:
545563
# DatetimeIndex and TimedeltaIndex with freq == None raise ValueError
546564
# on add/sub of integers (or int-like). We re-raise as a TypeError.
565+
if keep_null_freq:
566+
# TODO: remove keep_null_freq after Timestamp+int deprecation
567+
# GH#22535 is enforced
568+
raise
547569
raise TypeError(
548570
"incompatible type for a datetime/timedelta "
549571
"operation [{name}]".format(name=op.__name__)
@@ -615,25 +637,29 @@ def wrapper(left, right):
615637
if isinstance(right, ABCDataFrame):
616638
return NotImplemented
617639

640+
keep_null_freq = isinstance(
641+
right,
642+
(ABCDatetimeIndex, ABCDatetimeArray, ABCTimedeltaIndex, ABCTimedeltaArray),
643+
)
644+
618645
left, right = _align_method_SERIES(left, right)
619646
res_name = get_op_result_name(left, right)
620-
right = maybe_upcast_for_op(right, left.shape)
621647

622-
if should_extension_dispatch(left, right):
623-
result = dispatch_to_extension_op(op, left, right)
648+
lvalues = extract_array(left, extract_numpy=True)
649+
rvalues = extract_array(right, extract_numpy=True)
624650

625-
elif is_timedelta64_dtype(right) or isinstance(
626-
right, (ABCDatetimeArray, ABCDatetimeIndex)
627-
):
628-
# We should only get here with td64 right with non-scalar values
629-
# for right upcast by maybe_upcast_for_op
630-
assert not isinstance(right, (np.timedelta64, np.ndarray))
631-
result = op(left._values, right)
651+
rvalues = maybe_upcast_for_op(rvalues, lvalues.shape)
632652

633-
else:
634-
lvalues = extract_array(left, extract_numpy=True)
635-
rvalues = extract_array(right, extract_numpy=True)
653+
if should_extension_dispatch(lvalues, rvalues):
654+
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)
655+
656+
elif is_timedelta64_dtype(rvalues) or isinstance(rvalues, ABCDatetimeArray):
657+
# We should only get here with td64 rvalues with non-scalar values
658+
# for rvalues upcast by maybe_upcast_for_op
659+
assert not isinstance(rvalues, (np.timedelta64, np.ndarray))
660+
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)
636661

662+
else:
637663
with np.errstate(all="ignore"):
638664
result = na_op(lvalues, rvalues)
639665

@@ -708,25 +734,25 @@ def wrapper(self, other, axis=None):
708734
if len(self) != len(other):
709735
raise ValueError("Lengths must match to compare")
710736

711-
if should_extension_dispatch(self, other):
712-
res_values = dispatch_to_extension_op(op, self, other)
737+
lvalues = extract_array(self, extract_numpy=True)
738+
rvalues = extract_array(other, extract_numpy=True)
713739

714-
elif is_scalar(other) and isna(other):
740+
if should_extension_dispatch(lvalues, rvalues):
741+
res_values = dispatch_to_extension_op(op, lvalues, rvalues)
742+
743+
elif is_scalar(rvalues) and isna(rvalues):
715744
# numpy does not like comparisons vs None
716745
if op is operator.ne:
717-
res_values = np.ones(len(self), dtype=bool)
746+
res_values = np.ones(len(lvalues), dtype=bool)
718747
else:
719-
res_values = np.zeros(len(self), dtype=bool)
748+
res_values = np.zeros(len(lvalues), dtype=bool)
720749

721750
else:
722-
lvalues = extract_array(self, extract_numpy=True)
723-
rvalues = extract_array(other, extract_numpy=True)
724-
725751
with np.errstate(all="ignore"):
726752
res_values = na_op(lvalues, rvalues)
727753
if is_scalar(res_values):
728754
raise TypeError(
729-
"Could not compare {typ} type with Series".format(typ=type(other))
755+
"Could not compare {typ} type with Series".format(typ=type(rvalues))
730756
)
731757

732758
result = self._constructor(res_values, index=self.index)

0 commit comments

Comments
 (0)