REF: do extract_array earlier in series arith/comparison ops #28066


Merged: 10 commits, Sep 2, 2019
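The gist of the refactor: Series/Index operands are unboxed with pandas.core.construction.extract_array before the arithmetic/comparison op runs, so the op only ever sees ndarrays or ExtensionArrays. A minimal sketch of that helper (assuming the internal API of this era; extract_numpy=True additionally unwraps PandasArray into a plain ndarray):

import pandas as pd
from pandas.core.construction import extract_array

ser = pd.Series([1, 2, 3], name="x")
vals = extract_array(ser, extract_numpy=True)
print(type(vals))  # expected: <class 'numpy.ndarray'>, the Series box is gone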
92 changes: 57 additions & 35 deletions in pandas/core/ops/__init__.py
@@ -34,10 +34,11 @@
ABCIndexClass,
ABCSeries,
ABCSparseSeries,
ABCTimedeltaArray,
ABCTimedeltaIndex,
)
from pandas.core.dtypes.missing import isna, notna

import pandas as pd
from pandas._typing import ArrayLike
from pandas.core.construction import array, extract_array
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY, define_na_arithmetic_op
@@ -157,12 +158,13 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
if isna(obj):
# wrapping timedelta64("NaT") in Timedelta returns NaT,
# which would incorrectly be treated as a datetime-NaT, so
# we broadcast and wrap in a Series
# we broadcast and wrap in a TimedeltaArray
obj = obj.astype("timedelta64[ns]")
right = np.broadcast_to(obj, shape)

# Note: we use Series instead of TimedeltaIndex to avoid having
# to worry about catching NullFrequencyError.
return pd.Series(right)
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(right)

# In particular non-nanosecond timedelta64 needs to be cast to
# nanoseconds, or else we get undesired behavior like
@@ -173,7 +175,9 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
# GH#22390 Unfortunately we need to special-case right-hand
# timedelta64 dtypes because numpy casts integer dtypes to
# timedelta64 when operating with timedelta64
return pd.TimedeltaIndex(obj)
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray._from_sequence(obj)
return obj
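A small user-facing sketch of the behavior the timedelta-NaT handling above protects (an illustration, not part of the diff; assumes a timedelta64 Series on the left):

import numpy as np
import pandas as pd

ser = pd.Series(pd.to_timedelta([1, 2], unit="s"))
# the scalar NaT below is a timedelta64 NaT; maybe_upcast_for_op broadcasts it
# and wraps it in a TimedeltaArray so it is not mistaken for a datetime NaT
result = ser + np.timedelta64("NaT")
print(result.dtype)  # expected: timedelta64[ns], with all values NaT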


@@ -520,13 +524,29 @@ def column_op(a, b):
return result


def dispatch_to_extension_op(op, left, right):
def dispatch_to_extension_op(op, left, right, keep_null_freq: bool = False):
Member Author:

@simonjayhawkins do we have a way of typing left and right as "not a Series or Index"?

Contributor:

can u type left here (EA / np.ndarray)

Member (replying to the typing question above):

i'm not aware of being able to exclude types.

if a particular type raises (in this case TypeError?) then maybe could use overloads with a return type of NoReturn https://mypy.readthedocs.io/en/latest/more_types.html#the-noreturn-type (New in version 3.5.4)

could maybe use the following pattern to allow checking with older Python...

if TYPE_CHECKING:
    from typing import NoReturn
else:
    NoReturn = None
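A rough sketch of that NoReturn/overload idea (not part of this PR; the signatures are illustrative assumptions, with NoReturn overloads telling mypy that Series/Index arguments are an error):

from typing import TYPE_CHECKING, Union, overload

import numpy as np

if TYPE_CHECKING:
    from typing import NoReturn

    from pandas import Index, Series
    from pandas.core.arrays import ExtensionArray
else:
    NoReturn = None


@overload
def dispatch_to_extension_op(op, left: "Series", right: object) -> "NoReturn":
    ...


@overload
def dispatch_to_extension_op(op, left: "Index", right: object) -> "NoReturn":
    ...


@overload
def dispatch_to_extension_op(
    op, left: Union[np.ndarray, "ExtensionArray"], right: object
) -> Union[np.ndarray, "ExtensionArray"]:
    ...


def dispatch_to_extension_op(op, left, right, keep_null_freq: bool = False):
    ...  # runtime implementation unchanged; only the overloads constrain callers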

"""
Assume that left or right is a Series backed by an ExtensionArray,
apply the operator defined by op.

Parameters
----------
op : binary operator
left : ExtensionArray or np.ndarray
right : object
keep_null_freq : bool, default False
Whether to re-raise a NullFrequencyError unchanged, as opposed to
catching and raising TypeError.

Returns
-------
ExtensionArray or np.ndarray
2-tuple of these if op is divmod or rdivmod
"""
# NB: left and right should already be unboxed, so neither should be
# a Series or Index.

if left.dtype.kind in "mM":
if left.dtype.kind in "mM" and isinstance(left, np.ndarray):
# We need to cast datetime64 and timedelta64 ndarrays to
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in
# PandasArray as that behaves poorly with e.g. IntegerArray.
@@ -535,15 +555,13 @@ def dispatch_to_extension_op(op, left, right):
# The op calls will raise TypeError if the op is not defined
# on the ExtensionArray

# unbox Series and Index to arrays
new_left = extract_array(left, extract_numpy=True)
new_right = extract_array(right, extract_numpy=True)

try:
res_values = op(new_left, new_right)
res_values = op(left, right)
except NullFrequencyError:
# DatetimeIndex and TimedeltaIndex with freq == None raise ValueError
# on add/sub of integers (or int-like). We re-raise as a TypeError.
if keep_null_freq:
raise
raise TypeError(
"incompatible type for a datetime/timedelta "
"operation [{name}]".format(name=op.__name__)
@@ -615,25 +633,29 @@ def wrapper(left, right):
if isinstance(right, ABCDataFrame):
return NotImplemented

keep_null_freq = isinstance(
right,
(ABCDatetimeIndex, ABCDatetimeArray, ABCTimedeltaIndex, ABCTimedeltaArray),
)

left, right = _align_method_SERIES(left, right)
res_name = get_op_result_name(left, right)
right = maybe_upcast_for_op(right, left.shape)

if should_extension_dispatch(left, right):
result = dispatch_to_extension_op(op, left, right)
lvalues = extract_array(left, extract_numpy=True)
rvalues = extract_array(right, extract_numpy=True)

elif is_timedelta64_dtype(right) or isinstance(
right, (ABCDatetimeArray, ABCDatetimeIndex)
):
# We should only get here with td64 right with non-scalar values
# for right upcast by maybe_upcast_for_op
assert not isinstance(right, (np.timedelta64, np.ndarray))
result = op(left._values, right)
rvalues = maybe_upcast_for_op(rvalues, lvalues.shape)

else:
lvalues = extract_array(left, extract_numpy=True)
rvalues = extract_array(right, extract_numpy=True)
if should_extension_dispatch(left, rvalues):
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)

elif is_timedelta64_dtype(rvalues) or isinstance(rvalues, ABCDatetimeArray):
# We should only get here with td64 rvalues with non-scalar values
# for rvalues upcast by maybe_upcast_for_op
assert not isinstance(rvalues, (np.timedelta64, np.ndarray))
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq)

else:
with np.errstate(all="ignore"):
result = na_op(lvalues, rvalues)

@@ -708,25 +730,25 @@ def wrapper(self, other, axis=None):
if len(self) != len(other):
raise ValueError("Lengths must match to compare")

if should_extension_dispatch(self, other):
res_values = dispatch_to_extension_op(op, self, other)
lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

elif is_scalar(other) and isna(other):
if should_extension_dispatch(lvalues, rvalues):
res_values = dispatch_to_extension_op(op, lvalues, rvalues)

elif is_scalar(rvalues) and isna(rvalues):
# numpy does not like comparisons vs None
if op is operator.ne:
res_values = np.ones(len(self), dtype=bool)
res_values = np.ones(len(lvalues), dtype=bool)
else:
res_values = np.zeros(len(self), dtype=bool)
res_values = np.zeros(len(lvalues), dtype=bool)

else:
lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

with np.errstate(all="ignore"):
res_values = na_op(lvalues, rvalues)
if is_scalar(res_values):
raise TypeError(
"Could not compare {typ} type with Series".format(typ=type(other))
"Could not compare {typ} type with Series".format(typ=type(rvalues))
)

result = self._constructor(res_values, index=self.index)
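The scalar-NA branch above answers comparisons against None/NaN directly; a small sketch (not from the PR) of the resulting behavior:

import numpy as np
import pandas as pd

ser = pd.Series([1.0, 2.0, 3.0])
# numpy does not like comparisons vs None, so the wrapper short-circuits:
print((ser == None).values)   # [False False False]
print((ser != np.nan).values) # [ True  True  True]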