Skip to content

Commit 04a0b86

Browse files
[ArrayManager] Fix window operations with axis=1 (#40251)
1 parent 14c8892 commit 04a0b86

File tree

5 files changed

+41
-6
lines changed

5 files changed

+41
-6
lines changed

.github/workflows/ci.yml

+1
Original file line number | Diff line number | Diff line change
@@ -192,3 +192,4 @@ jobs:
192192
pytest pandas/tests/tseries/
193193
pytest pandas/tests/tslibs/
194194
pytest pandas/tests/util/
195+
pytest pandas/tests/window/

pandas/conftest.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -190,7 +190,7 @@ def add_imports(doctest_namespace):
190190
# ----------------------------------------------------------------
191191
# Common arguments
192192
# ----------------------------------------------------------------
193-
@pytest.fixture(params=[0, 1, "index", "columns"], ids=lambda x: f"axis {repr(x)}")
193+
@pytest.fixture(params=[0, 1, "index", "columns"], ids=lambda x: f"axis={repr(x)}")
194194
def axis(request):
195195
"""
196196
Fixture for returning the axis numbers of a DataFrame.

pandas/core/internals/array_manager.py

+24
Original file line number | Diff line number | Diff line change
@@ -406,6 +406,30 @@ def apply(
406406

407407
return type(self)(result_arrays, new_axes)
408408

409+
def apply_2d(
410+
self: T,
411+
f,
412+
ignore_failures: bool = False,
413+
**kwargs,
414+
) -> T:
415+
"""
416+
Variant of `apply`, but where the function should not be applied to
417+
each column independently, but to the full data as a 2D array.
418+
"""
419+
values = self.as_array()
420+
try:
421+
result = f(values, **kwargs)
422+
except (TypeError, NotImplementedError):
423+
if not ignore_failures:
424+
raise
425+
result_arrays = []
426+
new_axes = [self._axes[0], self.axes[1].take([])]
427+
else:
428+
result_arrays = [result[:, i] for i in range(len(self._axes[1]))]
429+
new_axes = self._axes
430+
431+
return type(self)(result_arrays, new_axes)
432+
409433
def apply_with_block(self: T, f, align_keys=None, **kwargs) -> T:
410434

411435
align_keys = align_keys or []

pandas/core/window/rolling.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@
6969
Index,
7070
MultiIndex,
7171
)
72+
from pandas.core.internals import ArrayManager
7273
from pandas.core.reshape.concat import concat
7374
from pandas.core.util.numba_ import (
7475
NUMBA_FUNC_CACHE,
@@ -410,7 +411,14 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
410411
res_values = homogeneous_func(values)
411412
return getattr(res_values, "T", res_values)
412413

413-
new_mgr = mgr.apply(hfunc, ignore_failures=True)
414+
def hfunc2d(values: ArrayLike) -> ArrayLike:
415+
values = self._prep_values(values)
416+
return homogeneous_func(values)
417+
418+
if isinstance(mgr, ArrayManager) and self.axis == 1:
419+
new_mgr = mgr.apply_2d(hfunc2d, ignore_failures=True)
420+
else:
421+
new_mgr = mgr.apply(hfunc, ignore_failures=True)
414422
out = obj._constructor(new_mgr)
415423

416424
if out.shape[1] == 0 and obj.shape[1] > 0:

pandas/tests/window/test_rolling.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,7 @@ def test_rolling_datetime(axis_frame, tz_naive_fixture):
397397
tm.assert_frame_equal(result, expected)
398398

399399

400-
def test_rolling_window_as_string():
400+
def test_rolling_window_as_string(using_array_manager):
401401
# see gh-22590
402402
date_today = datetime.now()
403403
days = date_range(date_today, date_today + timedelta(365), freq="D")
@@ -450,9 +450,11 @@ def test_rolling_window_as_string():
450450
+ [95.0] * 20
451451
)
452452

453-
expected = Series(
454-
expData, index=days.rename("DateCol")._with_freq(None), name="metric"
455-
)
453+
index = days.rename("DateCol")
454+
if not using_array_manager:
455+
# INFO(ArrayManager) preserves the frequency of the index
456+
index = index._with_freq(None)
457+
expected = Series(expData, index=index, name="metric")
456458
tm.assert_series_equal(result, expected)
457459

458460

0 commit comments

Comments
 (0)