Skip to content

REF: prepare (upcast) scalar before dispatching to arithmetic array ops #40479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def _cmp_method(self, other, op):
if isinstance(other, PandasArray):
other = other._ndarray

other = ops.maybe_prepare_scalar_for_op(other, (len(self),))
pd_op = ops.get_array_op(op)
other = ensure_wrapped_if_datetimelike(other)
with np.errstate(all="ignore"):
Expand Down
1 change: 1 addition & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6804,6 +6804,7 @@ def _arith_method(self, other, op):
return ops.frame_arith_method_with_reindex(self, other, op)

axis = 1 # only relevant for Series other case
other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis],))

self, other = ops.align_method_FRAME(self, other, axis, flex=True, level=None)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
comparison_op,
get_array_op,
logical_op,
maybe_prepare_scalar_for_op,
)
from pandas.core.ops.common import ( # noqa:F401
get_op_result_name,
Expand Down Expand Up @@ -428,6 +429,7 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):

axis = self._get_axis_number(axis) if axis is not None else 1

other = maybe_prepare_scalar_for_op(other, self.shape)
self, other = align_method_FRAME(self, other, axis, flex=True, level=level)

if isinstance(other, ABCDataFrame):
Expand Down
7 changes: 3 additions & 4 deletions pandas/core/ops/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,11 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
ndarray or ExtensionArray
Or a 2-tuple of these in the case of divmod or rdivmod.
"""

# NB: We assume that extract_array and ensure_wrapped_if_datetimelike
# has already been called on `left` and `right`.
# have already been called on `left` and `right`,
# and `maybe_prepare_scalar_for_op` has already been called on `right`
# We need to special-case datetime64/timedelta64 dtypes (e.g. because numpy
# casts integer dtypes to timedelta64 when operating with timedelta64 - GH#22390)
right = _maybe_upcast_for_op(right, left.shape)

if (
should_extension_dispatch(left, right)
Expand Down Expand Up @@ -439,7 +438,7 @@ def get_array_op(op):
raise NotImplementedError(op_name)


def _maybe_upcast_for_op(obj, shape: Shape):
def maybe_prepare_scalar_for_op(obj, shape: Shape):
"""
Cast non-pandas objects to pandas types to unify behavior of arithmetic
and comparison operations.
Expand Down
1 change: 1 addition & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5315,6 +5315,7 @@ def _arith_method(self, other, op):

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
Expand Down