Skip to content

Commit b845cb4

Browse files
jbrockmendelfeefladder
authored andcommitted
PERF: Rolling._apply (pandas-dev#43171)
* PERF: Rolling._apply * consolidate * remove consolidate arg * fix construction * fix warning
1 parent 1d56dae commit b845cb4

File tree

2 files changed

+31
-47
lines changed

2 files changed

+31
-47
lines changed

pandas/core/internals/array_manager.py

-21
Original file line numberDiff line numberDiff line change
@@ -1034,27 +1034,6 @@ def quantile(
10341034
axes = [qs, self._axes[1]]
10351035
return type(self)(new_arrs, axes)
10361036

1037-
def apply_2d(
1038-
self: ArrayManager, f, ignore_failures: bool = False, **kwargs
1039-
) -> ArrayManager:
1040-
"""
1041-
Variant of `apply`, but where the function should not be applied to
1042-
each column independently, but to the full data as a 2D array.
1043-
"""
1044-
values = self.as_array()
1045-
try:
1046-
result = f(values, **kwargs)
1047-
except (TypeError, NotImplementedError):
1048-
if not ignore_failures:
1049-
raise
1050-
result_arrays = []
1051-
new_axes = [self._axes[0], self.axes[1].take([])]
1052-
else:
1053-
result_arrays = [result[:, i] for i in range(len(self._axes[1]))]
1054-
new_axes = self._axes
1055-
1056-
return type(self)(result_arrays, new_axes)
1057-
10581037
# ----------------------------------------------------------------
10591038

10601039
def unstack(self, unstacker, fill_value) -> ArrayManager:

pandas/core/window/rolling.py

+31-26
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
PeriodIndex,
6969
TimedeltaIndex,
7070
)
71-
from pandas.core.internals import ArrayManager
7271
from pandas.core.reshape.concat import concat
7372
from pandas.core.util.numba_ import (
7473
NUMBA_FUNC_CACHE,
@@ -421,26 +420,39 @@ def _apply_blockwise(
421420
# GH 12541: Special case for count where we support date-like types
422421
obj = notna(obj).astype(int)
423422
obj._mgr = obj._mgr.consolidate()
424-
mgr = obj._mgr
425423

426-
def hfunc(bvalues: ArrayLike) -> ArrayLike:
427-
# TODO(EA2D): getattr unnecessary with 2D EAs
428-
values = self._prep_values(getattr(bvalues, "T", bvalues))
429-
res_values = homogeneous_func(values)
430-
return getattr(res_values, "T", res_values)
431-
432-
def hfunc2d(values: ArrayLike) -> ArrayLike:
424+
def hfunc(values: ArrayLike) -> ArrayLike:
433425
values = self._prep_values(values)
434426
return homogeneous_func(values)
435427

436-
if isinstance(mgr, ArrayManager) and self.axis == 1:
437-
new_mgr = mgr.apply_2d(hfunc2d, ignore_failures=True)
438-
else:
439-
new_mgr = mgr.apply(hfunc, ignore_failures=True)
428+
if self.axis == 1:
429+
obj = obj.T
440430

441-
if 0 != len(new_mgr.items) != len(mgr.items):
431+
taker = []
432+
res_values = []
433+
for i, arr in enumerate(obj._iter_column_arrays()):
434+
# GH#42736 operate column-wise instead of block-wise
435+
try:
436+
res = hfunc(arr)
437+
except (TypeError, NotImplementedError):
438+
pass
439+
else:
440+
res_values.append(res)
441+
taker.append(i)
442+
443+
df = type(obj)._from_arrays(
444+
res_values,
445+
index=obj.index,
446+
columns=obj.columns.take(taker),
447+
verify_integrity=False,
448+
)
449+
450+
if self.axis == 1:
451+
df = df.T
452+
453+
if 0 != len(res_values) != len(obj.columns):
442454
# GH#42738 ignore_failures dropped nuisance columns
443-
dropped = mgr.items.difference(new_mgr.items)
455+
dropped = obj.columns.difference(obj.columns.take(taker))
444456
warnings.warn(
445457
"Dropping of nuisance columns in rolling operations "
446458
"is deprecated; in a future version this will raise TypeError. "
@@ -449,9 +461,8 @@ def hfunc2d(values: ArrayLike) -> ArrayLike:
449461
FutureWarning,
450462
stacklevel=find_stack_level(),
451463
)
452-
out = obj._constructor(new_mgr)
453464

454-
return self._resolve_output(out, obj)
465+
return self._resolve_output(df, obj)
455466

456467
def _apply_tablewise(
457468
self, homogeneous_func: Callable[..., ArrayLike], name: str | None = None
@@ -540,10 +551,7 @@ def calc(x):
540551
return func(x, start, end, min_periods, *numba_args)
541552

542553
with np.errstate(all="ignore"):
543-
if values.ndim > 1 and self.method == "single":
544-
result = np.apply_along_axis(calc, self.axis, values)
545-
else:
546-
result = calc(values)
554+
result = calc(values)
547555

548556
if numba_cache_key is not None:
549557
NUMBA_FUNC_CACHE[numba_cache_key] = func
@@ -1024,11 +1032,8 @@ def calc(x):
10241032
return func(x, window, self.min_periods or len(window))
10251033

10261034
with np.errstate(all="ignore"):
1027-
if values.ndim > 1:
1028-
result = np.apply_along_axis(calc, self.axis, values)
1029-
else:
1030-
# Our weighted aggregations return memoryviews
1031-
result = np.asarray(calc(values))
1035+
# Our weighted aggregations return memoryviews
1036+
result = np.asarray(calc(values))
10321037

10331038
if self.center:
10341039
result = self._center_window(result, offset)

0 commit comments

Comments
 (0)