Skip to content

Commit b66ba5a

Browse files
committed
BUG: Respect axis when doing DataFrame.expanding
Closes gh-23372.
1 parent 360e727 commit b66ba5a

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,7 @@ Groupby/Resample/Rolling
12131213
- Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`)
12141214
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
12151215
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
1216+
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
12161217

12171218
Reshaping
12181219
^^^^^^^^^

pandas/core/window.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1866,12 +1866,11 @@ def _constructor(self):
18661866
return Expanding
18671867

18681868
def _get_window(self, other=None):
1869-
obj = self._selected_obj
1870-
if other is None:
1871-
return (max(len(obj), self.min_periods) if self.min_periods
1872-
else len(obj))
1873-
return (max((len(obj) + len(obj)), self.min_periods)
1874-
if self.min_periods else (len(obj) + len(obj)))
1869+
axis = self.obj._get_axis(self.axis)
1870+
length = len(axis) + (other is not None) * len(axis)
1871+
1872+
other = self.min_periods or -1
1873+
return max(length, other)
18751874

18761875
_agg_doc = dedent("""
18771876
Examples

pandas/tests/test_window.py

+32
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,22 @@ def test_iter_raises(self, klass):
627627
with pytest.raises(NotImplementedError):
628628
iter(obj.rolling(2))
629629

630+
@pytest.mark.parametrize("axis,expected", [
631+
(0, DataFrame({
632+
i: [np.nan] * 2 + [3.0] * 8
633+
for i in range(20)
634+
})),
635+
(1, DataFrame([
636+
[np.nan] * 2 + [3.0] * 18
637+
] * 10))
638+
])
639+
def test_rolling_axis(self, axis, expected):
640+
# see gh-23372.
641+
df = DataFrame(np.ones((10, 20)))
642+
643+
result = df.rolling(3, axis=axis).sum()
644+
tm.assert_frame_equal(result, expected)
645+
630646

631647
class TestExpanding(Base):
632648

@@ -714,6 +730,22 @@ def test_iter_raises(self, klass):
714730
with pytest.raises(NotImplementedError):
715731
iter(obj.expanding(2))
716732

733+
@pytest.mark.parametrize("axis,expected", [
734+
(0, DataFrame({
735+
i: [np.nan] * 2 + [float(j) for j in range(3, 11)]
736+
for i in range(20)
737+
})),
738+
(1, DataFrame([
739+
[np.nan] * 2 + [float(i) for i in range(3, 21)]
740+
] * 10))
741+
])
742+
def test_expanding_axis(self, axis, expected):
743+
# see gh-23372.
744+
df = DataFrame(np.ones((10, 20)))
745+
746+
result = df.expanding(3, axis=axis).sum()
747+
tm.assert_frame_equal(result, expected)
748+
717749

718750
class TestEWM(Base):
719751

0 commit comments

Comments
 (0)