
Commit b9ad20a

PERF: block-wise arithmetic for frame-with-frame (#32779)
1 parent 0babe10 commit b9ad20a

13 files changed: +248 −33 lines changed

asv_bench/benchmarks/arithmetic.py (+53)

@@ -101,6 +101,59 @@ def time_frame_op_with_series_axis1(self, opname):
         getattr(operator, opname)(self.df, self.ser)


+class FrameWithFrameWide:
+    # Many-columns, mixed dtypes
+
+    params = [
+        [
+            # GH#32779 has discussion of which operators are included here
+            operator.add,
+            operator.floordiv,
+            operator.gt,
+        ]
+    ]
+    param_names = ["op"]
+
+    def setup(self, op):
+        # we choose dtypes so as to make the blocks
+        #  a) not perfectly match between right and left
+        #  b) appreciably bigger than single columns
+        n_cols = 2000
+        n_rows = 500
+
+        # construct dataframe with 2 blocks
+        arr1 = np.random.randn(n_rows, int(n_cols / 2)).astype("f8")
+        arr2 = np.random.randn(n_rows, int(n_cols / 2)).astype("f4")
+        df = pd.concat(
+            [pd.DataFrame(arr1), pd.DataFrame(arr2)], axis=1, ignore_index=True,
+        )
+        # should already be the case, but just to be sure
+        df._consolidate_inplace()
+
+        # TODO: GH#33198 the setting here shouldn't need two steps
+        arr1 = np.random.randn(n_rows, int(n_cols / 4)).astype("f8")
+        arr2 = np.random.randn(n_rows, int(n_cols / 2)).astype("i8")
+        arr3 = np.random.randn(n_rows, int(n_cols / 4)).astype("f8")
+        df2 = pd.concat(
+            [pd.DataFrame(arr1), pd.DataFrame(arr2), pd.DataFrame(arr3)],
+            axis=1,
+            ignore_index=True,
+        )
+        # should already be the case, but just to be sure
+        df2._consolidate_inplace()
+
+        self.left = df
+        self.right = df2
+
+    def time_op_different_blocks(self, op):
+        # blocks (and dtypes) are not aligned
+        op(self.left, self.right)
+
+    def time_op_same_blocks(self, op):
+        # blocks (and dtypes) are aligned
+        op(self.left, self.left)
+
+
 class Ops:

     params = [[True, False], ["default", 1]]
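
The benchmark can be approximated outside of asv with a short timing script. This is a minimal sketch, not part of the commit; it rebuilds the same frame shapes and mixed dtypes so that the blocks of the two operands do not line up, which is exactly the case the new block-wise path has to handle.

import operator
import timeit

import numpy as np
import pandas as pd

n_rows, n_cols = 500, 2000

# Left frame: two consolidated blocks (float64 and float32).
left = pd.concat(
    [
        pd.DataFrame(np.random.randn(n_rows, n_cols // 2).astype("f8")),
        pd.DataFrame(np.random.randn(n_rows, n_cols // 2).astype("f4")),
    ],
    axis=1,
    ignore_index=True,
)

# Right frame: three blocks (float64 / int64 / float64) that do not
# match the left frame's block layout.
right = pd.concat(
    [
        pd.DataFrame(np.random.randn(n_rows, n_cols // 4).astype("f8")),
        pd.DataFrame(np.random.randn(n_rows, n_cols // 2).astype("i8")),
        pd.DataFrame(np.random.randn(n_rows, n_cols // 4).astype("f8")),
    ],
    axis=1,
    ignore_index=True,
)

print(timeit.timeit(lambda: operator.add(left, right), number=10))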

doc/source/whatsnew/v1.1.0.rst (+1 −1)

@@ -611,7 +611,7 @@ Performance improvements
   and :meth:`~pandas.core.groupby.groupby.Groupby.last` (:issue:`34178`)
 - Performance improvement in :func:`factorize` for nullable (integer and boolean) dtypes (:issue:`33064`).
 - Performance improvement in reductions (sum, prod, min, max) for nullable (integer and boolean) dtypes (:issue:`30982`, :issue:`33261`, :issue:`33442`).
-
+- Performance improvement in arithmetic operations between two :class:`DataFrame` objects (:issue:`32779`)

 .. ---------------------------------------------------------------------------

pandas/_libs/internals.pyx (+1 −1)

@@ -49,7 +49,7 @@ cdef class BlockPlacement:
         else:
             # Cython memoryview interface requires ndarray to be writeable.
             arr = np.require(val, dtype=np.int64, requirements='W')
-            assert arr.ndim == 1
+            assert arr.ndim == 1, arr.shape
             self._as_array = arr
             self._has_array = True

pandas/core/arrays/datetimelike.py (+8 −2)

@@ -98,6 +98,10 @@ def _validate_comparison_value(self, other):

     @unpack_zerodim_and_defer(opname)
     def wrapper(self, other):
+        if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
+            # TODO: handle 2D-like listlikes
+            return op(self.ravel(), other.ravel()).reshape(self.shape)
+
         try:
             other = _validate_comparison_value(self, other)
         except InvalidComparison:

@@ -1308,18 +1312,20 @@ def _addsub_object_array(self, other: np.ndarray, op):
         """
         assert op in [operator.add, operator.sub]
         if len(other) == 1:
+            # If both 1D then broadcasting is unambiguous
+            # TODO(EA2D): require self.ndim == other.ndim here
             return op(self, other[0])

         warnings.warn(
-            "Adding/subtracting array of DateOffsets to "
+            "Adding/subtracting object-dtype array to "
             f"{type(self).__name__} not vectorized",
             PerformanceWarning,
         )

         # Caller is responsible for broadcasting if necessary
         assert self.shape == other.shape, (self.shape, other.shape)

-        res_values = op(self.astype("O"), np.array(other))
+        res_values = op(self.astype("O"), np.asarray(other))
         result = array(res_values.ravel())
         result = extract_array(result, extract_numpy=True).reshape(self.shape)
         return result
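
The new early return in wrapper handles 2D datetime-like arrays by raveling both operands, running the existing 1D comparison, and restoring the original shape. A minimal numpy-only sketch of that pattern follows; compare_1d is a hypothetical stand-in for the 1D comparison path, not a pandas function.

import numpy as np

def compare_1d(a, b):
    # stand-in for a comparison routine that only knows how to handle 1D input
    assert a.ndim == 1 and b.ndim == 1
    return a > b

left = np.arange(6, dtype="M8[ns]").reshape(2, 3)
right = np.full((2, 3), np.datetime64(2, "ns"))

# Shapes match, so broadcasting is unambiguous: flatten, compare, reshape.
result = compare_1d(left.ravel(), right.ravel()).reshape(left.shape)
print(result)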

pandas/core/frame.py (+3 −1)

@@ -455,6 +455,7 @@ def __init__(
             mgr = self._init_mgr(
                 data, axes=dict(index=index, columns=columns), dtype=dtype, copy=copy
             )
+
         elif isinstance(data, dict):
             mgr = init_dict(data, index, columns, dtype=dtype)
         elif isinstance(data, ma.MaskedArray):

@@ -5754,10 +5755,11 @@ def _construct_result(self, result) -> "DataFrame":
         -------
         DataFrame
         """
-        out = self._constructor(result, index=self.index, copy=False)
+        out = self._constructor(result, copy=False)
         # Pin columns instead of passing to constructor for compat with
         # non-unique columns case
         out.columns = self.columns
+        out.index = self.index
         return out

     def combine(

pandas/core/internals/managers.py (+47 −14)

@@ -1269,12 +1269,22 @@ def reindex_indexer(

         return type(self).from_blocks(new_blocks, new_axes)

-    def _slice_take_blocks_ax0(self, slice_or_indexer, fill_value=lib.no_default):
+    def _slice_take_blocks_ax0(
+        self, slice_or_indexer, fill_value=lib.no_default, only_slice: bool = False
+    ):
         """
         Slice/take blocks along axis=0.

         Overloaded for SingleBlock

+        Parameters
+        ----------
+        slice_or_indexer : slice, ndarray[bool], or list-like of ints
+        fill_value : scalar, default lib.no_default
+        only_slice : bool, default False
+            If True, we always return views on existing arrays, never copies.
+            This is used when called from ops.blockwise.operate_blockwise.
+
         Returns
         -------
         new_blocks : list of Block

@@ -1298,14 +1308,23 @@ def _slice_take_blocks_ax0(self, slice_or_indexer, fill_value=lib.no_default):
             if allow_fill and fill_value is None:
                 _, fill_value = maybe_promote(blk.dtype)

-            return [
-                blk.take_nd(
-                    slobj,
-                    axis=0,
-                    new_mgr_locs=slice(0, sllen),
-                    fill_value=fill_value,
-                )
-            ]
+            if not allow_fill and only_slice:
+                # GH#33597 slice instead of take, so we get
+                #  views instead of copies
+                blocks = [
+                    blk.getitem_block([ml], new_mgr_locs=i)
+                    for i, ml in enumerate(slobj)
+                ]
+                return blocks
+            else:
+                return [
+                    blk.take_nd(
+                        slobj,
+                        axis=0,
+                        new_mgr_locs=slice(0, sllen),
+                        fill_value=fill_value,
+                    )
+                ]

         if sl_type in ("slice", "mask"):
             blknos = self.blknos[slobj]

@@ -1342,11 +1361,25 @@ def _slice_take_blocks_ax0(self, slice_or_indexer, fill_value=lib.no_default):
                 blocks.append(newblk)

             else:
-                blocks.append(
-                    blk.take_nd(
-                        blklocs[mgr_locs.indexer], axis=0, new_mgr_locs=mgr_locs,
-                    )
-                )
+                # GH#32779 to avoid the performance penalty of copying,
+                #  we may try to only slice
+                taker = blklocs[mgr_locs.indexer]
+                max_len = max(len(mgr_locs), taker.max() + 1)
+                if only_slice:
+                    taker = lib.maybe_indices_to_slice(taker, max_len)
+
+                if isinstance(taker, slice):
+                    nb = blk.getitem_block(taker, new_mgr_locs=mgr_locs)
+                    blocks.append(nb)
+                elif only_slice:
+                    # GH#33597 slice instead of take, so we get
+                    #  views instead of copies
+                    for i, ml in zip(taker, mgr_locs):
+                        nb = blk.getitem_block([i], new_mgr_locs=ml)
+                        blocks.append(nb)
+                else:
+                    nb = blk.take_nd(taker, axis=0, new_mgr_locs=mgr_locs)
+                    blocks.append(nb)

         return blocks
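
The only_slice fast path leans on pandas._libs.lib.maybe_indices_to_slice, which collapses consecutive positions into a slice so that getitem_block can return a view instead of a copied take. A rough illustration against that internal helper (not public API; the required integer dtype and the exact return values can differ between pandas versions):

import numpy as np
from pandas._libs import lib

contiguous = np.array([2, 3, 4, 5], dtype=np.intp)
scattered = np.array([0, 3, 1], dtype=np.intp)

# Consecutive positions collapse to a slice -> the block is viewed, not copied.
print(lib.maybe_indices_to_slice(contiguous, 10))

# Non-consecutive positions come back as an ndarray -> take/copy, or one
# single-column view per position when only_slice=True.
print(lib.maybe_indices_to_slice(scattered, 10))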

pandas/core/ops/__init__.py (+4 −2)

@@ -26,6 +26,7 @@
     logical_op,
 )
 from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY  # noqa:F401
+from pandas.core.ops.blockwise import operate_blockwise
 from pandas.core.ops.common import unpack_zerodim_and_defer
 from pandas.core.ops.dispatch import should_series_dispatch
 from pandas.core.ops.docstrings import (

@@ -325,8 +326,9 @@ def dispatch_to_series(left, right, func, str_rep=None, axis=None):
     elif isinstance(right, ABCDataFrame):
         assert right._indexed_same(left)

-        def column_op(a, b):
-            return {i: func(a.iloc[:, i], b.iloc[:, i]) for i in range(len(a.columns))}
+        array_op = get_array_op(func, str_rep=str_rep)
+        bm = operate_blockwise(left, right, array_op)
+        return type(left)(bm)

     elif isinstance(right, ABCSeries) and axis == "columns":
         # We only get here if called via _combine_series_frame,
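
This is the core of the change: the frame-with-frame branch of dispatch_to_series no longer builds one Series operation per column, it hands the aligned operands to operate_blockwise, which runs the array op once per matching block pair and wraps the resulting BlockManager back into a DataFrame. The user-visible result is unchanged; only the per-column overhead goes away. A small sketch of a call that now takes this path:

import numpy as np
import pandas as pd

left = pd.DataFrame(np.random.randn(500, 1000), dtype="f8")
right = pd.DataFrame(np.random.randn(500, 1000), dtype="f4")

# Previously dispatched column-by-column (1000 Series ops); now performed
# as one vectorized addition per aligned block pair.
result = left + right
print(result.dtypes.value_counts())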

pandas/core/ops/array_ops.py (+11 −4)

@@ -6,6 +6,7 @@
 from functools import partial
 import operator
 from typing import Any, Optional, Tuple
+import warnings

 import numpy as np

@@ -120,7 +121,7 @@ def masked_arith_op(x: np.ndarray, y, op):
     return result


-def define_na_arithmetic_op(op, str_rep: str):
+def define_na_arithmetic_op(op, str_rep: Optional[str]):
     def na_op(x, y):
         return na_arithmetic_op(x, y, op, str_rep)

@@ -191,7 +192,8 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
     # NB: We assume that extract_array has already been called
     # on `left` and `right`.
     lvalues = maybe_upcast_datetimelike_array(left)
-    rvalues = maybe_upcast_for_op(right, lvalues.shape)
+    rvalues = maybe_upcast_datetimelike_array(right)
+    rvalues = maybe_upcast_for_op(rvalues, lvalues.shape)

     if should_extension_dispatch(lvalues, rvalues) or isinstance(rvalues, Timedelta):
         # Timedelta is included because numexpr will fail on it, see GH#31457

@@ -254,8 +256,13 @@ def comparison_op(
         res_values = comp_method_OBJECT_ARRAY(op, lvalues, rvalues)

     else:
-        with np.errstate(all="ignore"):
-            res_values = na_arithmetic_op(lvalues, rvalues, op, str_rep, is_cmp=True)
+        with warnings.catch_warnings():
+            # suppress warnings from numpy about element-wise comparison
+            warnings.simplefilter("ignore", DeprecationWarning)
+            with np.errstate(all="ignore"):
+                res_values = na_arithmetic_op(
+                    lvalues, rvalues, op, str_rep, is_cmp=True
+                )

     return res_values
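
The nested context managers in comparison_op silence, for that block only, the DeprecationWarning numpy can raise when an element-wise comparison has to fall back to scalar/object semantics (whether and what numpy warns depends on the numpy version and the operands). The pattern in isolation, as a sketch:

import warnings

import numpy as np

arr = np.arange(3)

with warnings.catch_warnings():
    # Ignore numpy's element-wise comparison warning inside this block only;
    # warnings raised elsewhere are unaffected.
    warnings.simplefilter("ignore", DeprecationWarning)
    with np.errstate(all="ignore"):
        res = arr == "a"  # may warn on some numpy versions without the filter

print(res)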

pandas/core/ops/blockwise.py (new file, +102)

@@ -0,0 +1,102 @@
+from typing import TYPE_CHECKING, List, Tuple
+
+import numpy as np
+
+from pandas._typing import ArrayLike
+
+if TYPE_CHECKING:
+    from pandas.core.internals.blocks import Block  # noqa:F401
+
+
+def operate_blockwise(left, right, array_op):
+    # At this point we have already checked
+    #  assert right._indexed_same(left)
+
+    res_blks: List["Block"] = []
+    rmgr = right._mgr
+    for n, blk in enumerate(left._mgr.blocks):
+        locs = blk.mgr_locs
+        blk_vals = blk.values
+
+        left_ea = not isinstance(blk_vals, np.ndarray)
+
+        rblks = rmgr._slice_take_blocks_ax0(locs.indexer, only_slice=True)
+
+        # Assertions are disabled for performance, but should hold:
+        #  if left_ea:
+        #      assert len(locs) == 1, locs
+        #      assert len(rblks) == 1, rblks
+        #      assert rblks[0].shape[0] == 1, rblks[0].shape
+
+        for k, rblk in enumerate(rblks):
+            right_ea = not isinstance(rblk.values, np.ndarray)
+
+            lvals, rvals = _get_same_shape_values(blk, rblk, left_ea, right_ea)
+
+            res_values = array_op(lvals, rvals)
+            if left_ea and not right_ea and hasattr(res_values, "reshape"):
+                res_values = res_values.reshape(1, -1)
+            nbs = rblk._split_op_result(res_values)
+
+            # Assertions are disabled for performance, but should hold:
+            #  if right_ea or left_ea:
+            #      assert len(nbs) == 1
+            #  else:
+            #      assert res_values.shape == lvals.shape, (res_values.shape, lvals.shape)
+
+            _reset_block_mgr_locs(nbs, locs)
+
+            res_blks.extend(nbs)
+
+    # Assertions are disabled for performance, but should hold:
+    #  slocs = {y for nb in res_blks for y in nb.mgr_locs.as_array}
+    #  nlocs = sum(len(nb.mgr_locs.as_array) for nb in res_blks)
+    #  assert nlocs == len(left.columns), (nlocs, len(left.columns))
+    #  assert len(slocs) == nlocs, (len(slocs), nlocs)
+    #  assert slocs == set(range(nlocs)), slocs
+
+    new_mgr = type(rmgr)(res_blks, axes=rmgr.axes, do_integrity_check=False)
+    return new_mgr
+
+
+def _reset_block_mgr_locs(nbs: List["Block"], locs):
+    """
+    Reset mgr_locs to correspond to our original DataFrame.
+    """
+    for nb in nbs:
+        nblocs = locs.as_array[nb.mgr_locs.indexer]
+        nb.mgr_locs = nblocs
+        # Assertions are disabled for performance, but should hold:
+        #  assert len(nblocs) == nb.shape[0], (len(nblocs), nb.shape)
+        #  assert all(x in locs.as_array for x in nb.mgr_locs.as_array)
+
+
+def _get_same_shape_values(
+    lblk: "Block", rblk: "Block", left_ea: bool, right_ea: bool
+) -> Tuple[ArrayLike, ArrayLike]:
+    """
+    Slice lblk.values to align with rblk.  Squeeze if we have EAs.
+    """
+    lvals = lblk.values
+    rvals = rblk.values
+
+    # Require that the indexing into lvals be slice-like
+    assert rblk.mgr_locs.is_slice_like, rblk.mgr_locs
+
+    # TODO(EA2D): with 2D EAs only this first clause would be needed
+    if not (left_ea or right_ea):
+        lvals = lvals[rblk.mgr_locs.indexer, :]
+        assert lvals.shape == rvals.shape, (lvals.shape, rvals.shape)
+    elif left_ea and right_ea:
+        assert lvals.shape == rvals.shape, (lvals.shape, rvals.shape)
+    elif right_ea:
+        # lvals are 2D, rvals are 1D
+        lvals = lvals[rblk.mgr_locs.indexer, :]
+        assert lvals.shape[0] == 1, lvals.shape
+        lvals = lvals[0, :]
+    else:
+        # lvals are 1D, rvals are 2D
+        assert rvals.shape[0] == 1, rvals.shape
+        rvals = rvals[0, :]
+
+    return lvals, rvals
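
For orientation, the new helper can be exercised directly. This is a hedged sketch against internal, non-public API as it exists in this revision (get_array_op is assumed to live in pandas.core.ops.array_ops here, and the blockwise module moved in later pandas versions), so treat the imports as revision-specific.

import operator

import pandas as pd
from pandas.core.ops.array_ops import get_array_op
from pandas.core.ops.blockwise import operate_blockwise

left = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
right = pd.DataFrame({"a": [10, 20, 30], "b": [0.5, 1.5, 2.5]})

# Caller is responsible for alignment; here both frames share index and columns.
array_op = get_array_op(operator.add)
new_mgr = operate_blockwise(left, right, array_op)

# dispatch_to_series wraps the resulting BlockManager back into a DataFrame.
print(pd.DataFrame(new_mgr))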
