Skip to content

REF: prepare (upcast) scalar before dispatching to arithmetic array ops #40479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def _cmp_method(self, other, op):
if isinstance(other, PandasArray):
other = other._ndarray

other = ops.maybe_prepare_scalar_for_op(other, (len(self),))
pd_op = ops.get_array_op(op)
other = ensure_wrapped_if_datetimelike(other)
with np.errstate(all="ignore"):
Expand Down
1 change: 1 addition & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6804,6 +6804,7 @@ def _arith_method(self, other, op):
return ops.frame_arith_method_with_reindex(self, other, op)

axis = 1 # only relevant for Series other case
other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis],))

self, other = ops.align_method_FRAME(self, other, axis, flex=True, level=None)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
comparison_op,
get_array_op,
logical_op,
maybe_prepare_scalar_for_op,
)
from pandas.core.ops.common import ( # noqa:F401
get_op_result_name,
Expand Down Expand Up @@ -428,6 +429,7 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):

axis = self._get_axis_number(axis) if axis is not None else 1

other = maybe_prepare_scalar_for_op(other, (self.shape[axis],))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why self.shape[axis] instead of self.shape?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because self is a DataFrame, and we want a 1D shape (rows or columns, depending on the axis).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing shape only matters when we have a scalar, in which case broadcasting to self.shape is simpler.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, changed to self.shape

self, other = align_method_FRAME(self, other, axis, flex=True, level=level)

if isinstance(other, ABCDataFrame):
Expand Down
7 changes: 3 additions & 4 deletions pandas/core/ops/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,11 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
ndarray or ExtensionArray
Or a 2-tuple of these in the case of divmod or rdivmod.
"""

# NB: We assume that extract_array and ensure_wrapped_if_datetimelike
# has already been called on `left` and `right`.
# have already been called on `left` and `right`,
# and `maybe_prepare_scalar_for_op` has already been called on `right`
# We need to special-case datetime64/timedelta64 dtypes (e.g. because numpy
# casts integer dtypes to timedelta64 when operating with timedelta64 - GH#22390)
right = _maybe_upcast_for_op(right, left.shape)

if should_extension_dispatch(left, right) or isinstance(right, Timedelta):
# Timedelta is included because numexpr will fail on it, see GH#31457
Expand Down Expand Up @@ -419,7 +418,7 @@ def get_array_op(op):
raise NotImplementedError(op_name)


def _maybe_upcast_for_op(obj, shape: Shape):
def maybe_prepare_scalar_for_op(obj, shape: Shape):
"""
Cast non-pandas objects to pandas types to unify behavior of arithmetic
and comparison operations.
Expand Down
1 change: 1 addition & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5315,6 +5315,7 @@ def _arith_method(self, other, op):

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
Expand Down