Skip to content

Commit 19ae087

Browse files
authored
PERF: do DataFrame.op(series, axis=0) blockwise (#31296)
1 parent 957fc3c commit 19ae087

File tree

8 files changed

+152
-27
lines changed

8 files changed

+152
-27
lines changed

asv_bench/benchmarks/arithmetic.py

+30
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ def time_frame_op_with_scalar(self, dtype, scalar, op):
5050
op(self.df, scalar)
5151

5252

53+
class MixedFrameWithSeriesAxis0:
54+
params = [
55+
[
56+
"eq",
57+
"ne",
58+
"lt",
59+
"le",
60+
"ge",
61+
"gt",
62+
"add",
63+
"sub",
64+
"div",
65+
"floordiv",
66+
"mul",
67+
"pow",
68+
]
69+
]
70+
param_names = ["opname"]
71+
72+
def setup(self, opname):
73+
arr = np.arange(10 ** 6).reshape(100, -1)
74+
df = DataFrame(arr)
75+
df["C"] = 1.0
76+
self.df = df
77+
self.ser = df[0]
78+
79+
def time_frame_op_with_series_axis0(self, opname):
80+
getattr(self.df, opname)(self.ser, axis=0)
81+
82+
5383
class Ops:
5484

5585
params = [[True, False], ["default", 1]]

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ Performance improvements
185185

186186
- Performance improvement in :class:`Timedelta` constructor (:issue:`30543`)
187187
- Performance improvement in :class:`Timestamp` constructor (:issue:`30543`)
188-
-
188+
- Performance improvement in flex arithmetic ops between :class:`DataFrame` and :class:`Series` with ``axis=0`` (:issue:`31296`)
189189
-
190190

191191
.. ---------------------------------------------------------------------------

pandas/_libs/ops.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def scalar_compare(object[:] values, object val, object op):
100100

101101
@cython.wraparound(False)
102102
@cython.boundscheck(False)
103-
def vec_compare(object[:] left, object[:] right, object op):
103+
def vec_compare(ndarray[object] left, ndarray[object] right, object op):
104104
"""
105105
Compare the elements of `left` with the elements of `right` pointwise,
106106
with the comparison operation described by `op`.

pandas/core/frame.py

-14
Original file line numberDiff line numberDiff line change
@@ -5212,20 +5212,6 @@ def _arith_op(left, right):
52125212

52135213
return new_data
52145214

5215-
def _combine_match_index(self, other: Series, func):
5216-
# at this point we have `self.index.equals(other.index)`
5217-
5218-
if ops.should_series_dispatch(self, other, func):
5219-
# operate column-wise; avoid costly object-casting in `.values`
5220-
new_data = ops.dispatch_to_series(self, other, func)
5221-
else:
5222-
# fastpath --> operate directly on values
5223-
other_vals = other.values.reshape(-1, 1)
5224-
with np.errstate(all="ignore"):
5225-
new_data = func(self.values, other_vals)
5226-
new_data = dispatch_fill_zeros(func, self.values, other_vals, new_data)
5227-
return new_data
5228-
52295215
def _construct_result(self, result) -> "DataFrame":
52305216
"""
52315217
Wrap the result of an arithmetic, comparison, or logical operation.

pandas/core/ops/__init__.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0):
585585
# DataFrame
586586

587587

588-
def _combine_series_frame(left, right, func, axis: int):
588+
def _combine_series_frame(left, right, func, axis: int, str_rep: str):
589589
"""
590590
Apply binary operator `func` to self, other using alignment and fill
591591
conventions determined by the axis argument.
@@ -596,14 +596,25 @@ def _combine_series_frame(left, right, func, axis: int):
596596
right : Series
597597
func : binary operator
598598
axis : {0, 1}
599+
str_rep : str
599600
600601
Returns
601602
-------
602603
result : DataFrame
603604
"""
604605
# We assume that self.align(other, ...) has already been called
605606
if axis == 0:
606-
new_data = left._combine_match_index(right, func)
607+
values = right._values
608+
if isinstance(values, np.ndarray):
609+
# We can operate block-wise
610+
values = values.reshape(-1, 1)
611+
612+
array_op = get_array_op(func, str_rep=str_rep)
613+
bm = left._data.apply(array_op, right=values.T)
614+
return type(left)(bm)
615+
616+
new_data = dispatch_to_series(left, right, func)
617+
607618
else:
608619
new_data = dispatch_to_series(left, right, func, axis="columns")
609620

@@ -791,7 +802,9 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
791802
raise NotImplementedError(f"fill_value {fill_value} not supported.")
792803

793804
axis = self._get_axis_number(axis) if axis is not None else 1
794-
return _combine_series_frame(self, other, pass_op, axis=axis)
805+
return _combine_series_frame(
806+
self, other, pass_op, axis=axis, str_rep=str_rep
807+
)
795808
else:
796809
# in this case we always have `np.ndim(other) == 0`
797810
if fill_value is not None:
@@ -826,7 +839,7 @@ def f(self, other, axis=default_axis, level=None):
826839

827840
elif isinstance(other, ABCSeries):
828841
axis = self._get_axis_number(axis) if axis is not None else 1
829-
return _combine_series_frame(self, other, op, axis=axis)
842+
return _combine_series_frame(self, other, op, axis=axis, str_rep=str_rep)
830843
else:
831844
# in this case we always have `np.ndim(other) == 0`
832845
new_data = dispatch_to_series(self, other, op, str_rep)

pandas/core/ops/array_ops.py

+68-6
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
ABCDatetimeArray,
2929
ABCExtensionArray,
3030
ABCIndex,
31-
ABCIndexClass,
3231
ABCSeries,
3332
ABCTimedeltaArray,
3433
)
@@ -53,13 +52,15 @@ def comp_method_OBJECT_ARRAY(op, x, y):
5352
if isinstance(y, (ABCSeries, ABCIndex)):
5453
y = y.values
5554

56-
result = libops.vec_compare(x.ravel(), y, op)
55+
if x.shape != y.shape:
56+
raise ValueError("Shapes must match", x.shape, y.shape)
57+
result = libops.vec_compare(x.ravel(), y.ravel(), op)
5758
else:
5859
result = libops.scalar_compare(x.ravel(), y, op)
5960
return result.reshape(x.shape)
6061

6162

62-
def masked_arith_op(x, y, op):
63+
def masked_arith_op(x: np.ndarray, y, op):
6364
"""
6465
If the given arithmetic operation fails, attempt it again on
6566
only the non-null elements of the input array(s).
@@ -78,10 +79,22 @@ def masked_arith_op(x, y, op):
7879
dtype = find_common_type([x.dtype, y.dtype])
7980
result = np.empty(x.size, dtype=dtype)
8081

82+
if len(x) != len(y):
83+
if not _can_broadcast(x, y):
84+
raise ValueError(x.shape, y.shape)
85+
86+
# Call notna on pre-broadcasted y for performance
87+
ymask = notna(y)
88+
y = np.broadcast_to(y, x.shape)
89+
ymask = np.broadcast_to(ymask, x.shape)
90+
91+
else:
92+
ymask = notna(y)
93+
8194
# NB: ravel() is only safe since y is ndarray; for e.g. PeriodIndex
8295
# we would get int64 dtype, see GH#19956
8396
yrav = y.ravel()
84-
mask = notna(xrav) & notna(yrav)
97+
mask = notna(xrav) & ymask.ravel()
8598

8699
if yrav.shape != mask.shape:
87100
# FIXME: GH#5284, GH#5035, GH#19448
@@ -211,6 +224,51 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
211224
return res_values
212225

213226

227+
def _broadcast_comparison_op(lvalues, rvalues, op) -> np.ndarray:
228+
"""
229+
Broadcast a comparison operation between two 2D arrays.
230+
231+
Parameters
232+
----------
233+
lvalues : np.ndarray or ExtensionArray
234+
rvalues : np.ndarray or ExtensionArray
235+
236+
Returns
237+
-------
238+
np.ndarray[bool]
239+
"""
240+
if isinstance(rvalues, np.ndarray):
241+
rvalues = np.broadcast_to(rvalues, lvalues.shape)
242+
result = comparison_op(lvalues, rvalues, op)
243+
else:
244+
result = np.empty(lvalues.shape, dtype=bool)
245+
for i in range(len(lvalues)):
246+
result[i, :] = comparison_op(lvalues[i], rvalues[:, 0], op)
247+
return result
248+
249+
250+
def _can_broadcast(lvalues, rvalues) -> bool:
251+
"""
252+
Check if we can broadcast rvalues to match the shape of lvalues.
253+
254+
Parameters
255+
----------
256+
lvalues : np.ndarray or ExtensionArray
257+
rvalues : np.ndarray or ExtensionArray
258+
259+
Returns
260+
-------
261+
bool
262+
"""
263+
# We assume that lengths dont match
264+
if lvalues.ndim == rvalues.ndim == 2:
265+
# See if we can broadcast unambiguously
266+
if lvalues.shape[1] == rvalues.shape[-1]:
267+
if rvalues.shape[0] == 1:
268+
return True
269+
return False
270+
271+
214272
def comparison_op(
215273
left: ArrayLike, right: Any, op, str_rep: Optional[str] = None,
216274
) -> ArrayLike:
@@ -237,12 +295,16 @@ def comparison_op(
237295
# TODO: same for tuples?
238296
rvalues = np.asarray(rvalues)
239297

240-
if isinstance(rvalues, (np.ndarray, ABCExtensionArray, ABCIndexClass)):
298+
if isinstance(rvalues, (np.ndarray, ABCExtensionArray)):
241299
# TODO: make this treatment consistent across ops and classes.
242300
# We are not catching all listlikes here (e.g. frozenset, tuple)
243301
# The ambiguous case is object-dtype. See GH#27803
244302
if len(lvalues) != len(rvalues):
245-
raise ValueError("Lengths must match to compare")
303+
if _can_broadcast(lvalues, rvalues):
304+
return _broadcast_comparison_op(lvalues, rvalues, op)
305+
raise ValueError(
306+
"Lengths must match to compare", lvalues.shape, rvalues.shape
307+
)
246308

247309
if should_extension_dispatch(lvalues, rvalues):
248310
res_values = dispatch_to_extension_op(op, lvalues, rvalues)

pandas/tests/arithmetic/test_array_ops.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import pandas._testing as tm
7-
from pandas.core.ops.array_ops import na_logical_op
7+
from pandas.core.ops.array_ops import comparison_op, na_logical_op
88

99

1010
def test_na_logical_op_2d():
@@ -19,3 +19,18 @@ def test_na_logical_op_2d():
1919
result = na_logical_op(left, right, operator.or_)
2020
expected = right
2121
tm.assert_numpy_array_equal(result, expected)
22+
23+
24+
def test_object_comparison_2d():
25+
left = np.arange(9).reshape(3, 3).astype(object)
26+
right = left.T
27+
28+
result = comparison_op(left, right, operator.eq)
29+
expected = np.eye(3).astype(bool)
30+
tm.assert_numpy_array_equal(result, expected)
31+
32+
# Ensure that cython doesn't raise on non-writeable arg, which
33+
# we can get from np.broadcast_to
34+
right.flags.writeable = False
35+
result = comparison_op(left, right, operator.ne)
36+
tm.assert_numpy_array_equal(result, ~expected)

pandas/tests/frame/test_arithmetic.py

+19
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,25 @@ def test_floordiv_axis0(self):
348348
result2 = df.floordiv(ser.values, axis=0)
349349
tm.assert_frame_equal(result2, expected)
350350

351+
@pytest.mark.slow
352+
@pytest.mark.parametrize("opname", ["floordiv", "pow"])
353+
def test_floordiv_axis0_numexpr_path(self, opname):
354+
# case that goes through numexpr and has to fall back to masked_arith_op
355+
op = getattr(operator, opname)
356+
357+
arr = np.arange(10 ** 6).reshape(100, -1)
358+
df = pd.DataFrame(arr)
359+
df["C"] = 1.0
360+
361+
ser = df[0]
362+
result = getattr(df, opname)(ser, axis=0)
363+
364+
expected = pd.DataFrame({col: op(df[col], ser) for col in df.columns})
365+
tm.assert_frame_equal(result, expected)
366+
367+
result2 = getattr(df, opname)(ser.values, axis=0)
368+
tm.assert_frame_equal(result2, expected)
369+
351370
def test_df_add_td64_columnwise(self):
352371
# GH 22534 Check that column-wise addition broadcasts correctly
353372
dti = pd.date_range("2016-01-01", periods=10)

0 commit comments

Comments
 (0)