Skip to content

Commit f731f6d

Browse files
authored
REF: simplify broadcasting code (#33565)
1 parent 2a4ec05 commit f731f6d

File tree

3 files changed

+13
-63
lines changed

3 files changed

+13
-63
lines changed

pandas/core/internals/managers.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -390,20 +390,22 @@ def apply(self: T, f, align_keys=None, **kwargs) -> T:
390390
if f == "where":
391391
align_copy = True
392392

393-
aligned_args = {
394-
k: kwargs[k]
395-
for k in align_keys
396-
if isinstance(kwargs[k], (ABCSeries, ABCDataFrame))
397-
}
393+
aligned_args = {k: kwargs[k] for k in align_keys}
398394

399395
for b in self.blocks:
400396

401397
if aligned_args:
402398
b_items = self.items[b.mgr_locs.indexer]
403399

404400
for k, obj in aligned_args.items():
405-
axis = obj._info_axis_number
406-
kwargs[k] = obj.reindex(b_items, axis=axis, copy=align_copy)._values
401+
if isinstance(obj, (ABCSeries, ABCDataFrame)):
402+
axis = obj._info_axis_number
403+
kwargs[k] = obj.reindex(
404+
b_items, axis=axis, copy=align_copy
405+
)._values
406+
else:
407+
# otherwise we have an ndarray
408+
kwargs[k] = obj[b.mgr_locs.indexer]
407409

408410
if callable(f):
409411
applied = b.apply(f, **kwargs)

pandas/core/ops/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,13 @@ def _combine_series_frame(left, right, func, axis: int, str_rep: str):
510510
if axis == 0:
511511
values = right._values
512512
if isinstance(values, np.ndarray):
513+
# TODO(EA2D): no need to special-case with 2D EAs
513514
# We can operate block-wise
514515
values = values.reshape(-1, 1)
516+
values = np.broadcast_to(values, left.shape)
515517

516518
array_op = get_array_op(func, str_rep=str_rep)
517-
bm = left._mgr.apply(array_op, right=values.T)
519+
bm = left._mgr.apply(array_op, right=values.T, align_keys=["right"])
518520
return type(left)(bm)
519521

520522
new_data = dispatch_to_series(left, right, func)

pandas/core/ops/array_ops.py

+1-55
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,7 @@ def masked_arith_op(x: np.ndarray, y, op):
7575
result = np.empty(x.size, dtype=dtype)
7676

7777
if len(x) != len(y):
78-
if not _can_broadcast(x, y):
79-
raise ValueError(x.shape, y.shape)
80-
81-
# Call notna on pre-broadcasted y for performance
82-
ymask = notna(y)
83-
y = np.broadcast_to(y, x.shape)
84-
ymask = np.broadcast_to(ymask, x.shape)
85-
78+
raise ValueError(x.shape, y.shape)
8679
else:
8780
ymask = notna(y)
8881

@@ -211,51 +204,6 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
211204
return res_values
212205

213206

214-
def _broadcast_comparison_op(lvalues, rvalues, op) -> np.ndarray:
215-
"""
216-
Broadcast a comparison operation between two 2D arrays.
217-
218-
Parameters
219-
----------
220-
lvalues : np.ndarray or ExtensionArray
221-
rvalues : np.ndarray or ExtensionArray
222-
223-
Returns
224-
-------
225-
np.ndarray[bool]
226-
"""
227-
if isinstance(rvalues, np.ndarray):
228-
rvalues = np.broadcast_to(rvalues, lvalues.shape)
229-
result = comparison_op(lvalues, rvalues, op)
230-
else:
231-
result = np.empty(lvalues.shape, dtype=bool)
232-
for i in range(len(lvalues)):
233-
result[i, :] = comparison_op(lvalues[i], rvalues[:, 0], op)
234-
return result
235-
236-
237-
def _can_broadcast(lvalues, rvalues) -> bool:
238-
"""
239-
Check if we can broadcast rvalues to match the shape of lvalues.
240-
241-
Parameters
242-
----------
243-
lvalues : np.ndarray or ExtensionArray
244-
rvalues : np.ndarray or ExtensionArray
245-
246-
Returns
247-
-------
248-
bool
249-
"""
250-
# We assume that lengths don't match
251-
if lvalues.ndim == rvalues.ndim == 2:
252-
# See if we can broadcast unambiguously
253-
if lvalues.shape[1] == rvalues.shape[-1]:
254-
if rvalues.shape[0] == 1:
255-
return True
256-
return False
257-
258-
259207
def comparison_op(
260208
left: ArrayLike, right: Any, op, str_rep: Optional[str] = None,
261209
) -> ArrayLike:
@@ -287,8 +235,6 @@ def comparison_op(
287235
# We are not catching all listlikes here (e.g. frozenset, tuple)
288236
# The ambiguous case is object-dtype. See GH#27803
289237
if len(lvalues) != len(rvalues):
290-
if _can_broadcast(lvalues, rvalues):
291-
return _broadcast_comparison_op(lvalues, rvalues, op)
292238
raise ValueError(
293239
"Lengths must match to compare", lvalues.shape, rvalues.shape
294240
)

0 commit comments

Comments
 (0)