Skip to content

Commit f99ef5f

Browse files
jorisvandenbosscheyeshsurya
authored andcommitted
REF: prepare (upcast) scalar before dispatching to arithmetic array ops (pandas-dev#40479)
1 parent 1dd3e6e commit f99ef5f

File tree

5 files changed

+8
-4
lines changed

5 files changed

+8
-4
lines changed

pandas/core/arrays/numpy_.py

+1
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def _cmp_method(self, other, op):
395395
if isinstance(other, PandasArray):
396396
other = other._ndarray
397397

398+
other = ops.maybe_prepare_scalar_for_op(other, (len(self),))
398399
pd_op = ops.get_array_op(op)
399400
other = ensure_wrapped_if_datetimelike(other)
400401
with np.errstate(all="ignore"):

pandas/core/frame.py

+1
Original file line numberDiff line numberDiff line change
@@ -6804,6 +6804,7 @@ def _arith_method(self, other, op):
68046804
return ops.frame_arith_method_with_reindex(self, other, op)
68056805

68066806
axis = 1 # only relevant for Series other case
6807+
other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis],))
68076808

68086809
self, other = ops.align_method_FRAME(self, other, axis, flex=True, level=None)
68096810

pandas/core/ops/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
comparison_op,
3636
get_array_op,
3737
logical_op,
38+
maybe_prepare_scalar_for_op,
3839
)
3940
from pandas.core.ops.common import ( # noqa:F401
4041
get_op_result_name,
@@ -428,6 +429,7 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
428429

429430
axis = self._get_axis_number(axis) if axis is not None else 1
430431

432+
other = maybe_prepare_scalar_for_op(other, self.shape)
431433
self, other = align_method_FRAME(self, other, axis, flex=True, level=level)
432434

433435
if isinstance(other, ABCDataFrame):

pandas/core/ops/array_ops.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,11 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
202202
ndarray or ExtensionArray
203203
Or a 2-tuple of these in the case of divmod or rdivmod.
204204
"""
205-
206205
# NB: We assume that extract_array and ensure_wrapped_if_datetimelike
207-
# has already been called on `left` and `right`.
206+
# have already been called on `left` and `right`,
207+
# and `maybe_prepare_scalar_for_op` has already been called on `right`
208208
# We need to special-case datetime64/timedelta64 dtypes (e.g. because numpy
209209
# casts integer dtypes to timedelta64 when operating with timedelta64 - GH#22390)
210-
right = _maybe_upcast_for_op(right, left.shape)
211210

212211
if (
213212
should_extension_dispatch(left, right)
@@ -439,7 +438,7 @@ def get_array_op(op):
439438
raise NotImplementedError(op_name)
440439

441440

442-
def _maybe_upcast_for_op(obj, shape: Shape):
441+
def maybe_prepare_scalar_for_op(obj, shape: Shape):
443442
"""
444443
Cast non-pandas objects to pandas types to unify behavior of arithmetic
445444
and comparison operations.

pandas/core/series.py

+1
Original file line numberDiff line numberDiff line change
@@ -5315,6 +5315,7 @@ def _arith_method(self, other, op):
53155315

53165316
lvalues = self._values
53175317
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
5318+
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
53185319
rvalues = ensure_wrapped_if_datetimelike(rvalues)
53195320

53205321
with np.errstate(all="ignore"):

0 commit comments

Comments
 (0)