Skip to content

Commit b28bc34

Browse files
gfyoungPingviinituutti
authored andcommitted
BUG: Respect axis when doing DataFrame.expanding (pandas-dev#23402)
Closes pandas-devgh-23372.
1 parent 01b4de9 commit b28bc34

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,7 @@ Groupby/Resample/Rolling
12381238
- Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`)
12391239
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
12401240
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
1241+
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
12411242

12421243
Reshaping
12431244
^^^^^^^^^

pandas/core/window.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -1866,12 +1866,25 @@ def _constructor(self):
18661866
return Expanding
18671867

18681868
def _get_window(self, other=None):
1869-
obj = self._selected_obj
1870-
if other is None:
1871-
return (max(len(obj), self.min_periods) if self.min_periods
1872-
else len(obj))
1873-
return (max((len(obj) + len(obj)), self.min_periods)
1874-
if self.min_periods else (len(obj) + len(obj)))
1869+
"""
1870+
Get the window length over which to perform some operation.
1871+
1872+
Parameters
1873+
----------
1874+
other : object, default None
1875+
The other object that is involved in the operation.
1876+
Such an object is involved for operations like covariance.
1877+
1878+
Returns
1879+
-------
1880+
window : int
1881+
The window length.
1882+
"""
1883+
axis = self.obj._get_axis(self.axis)
1884+
length = len(axis) + (other is not None) * len(axis)
1885+
1886+
other = self.min_periods or -1
1887+
return max(length, other)
18751888

18761889
_agg_doc = dedent("""
18771890
Examples

pandas/tests/test_window.py

+38
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,25 @@ def test_iter_raises(self, klass):
627627
with pytest.raises(NotImplementedError):
628628
iter(obj.rolling(2))
629629

630+
def test_rolling_axis(self, axis_frame):
631+
# see gh-23372.
632+
df = DataFrame(np.ones((10, 20)))
633+
axis = df._get_axis_number(axis_frame)
634+
635+
if axis == 0:
636+
expected = DataFrame({
637+
i: [np.nan] * 2 + [3.0] * 8
638+
for i in range(20)
639+
})
640+
else:
641+
# axis == 1
642+
expected = DataFrame([
643+
[np.nan] * 2 + [3.0] * 18
644+
] * 10)
645+
646+
result = df.rolling(3, axis=axis_frame).sum()
647+
tm.assert_frame_equal(result, expected)
648+
630649

631650
class TestExpanding(Base):
632651

@@ -714,6 +733,25 @@ def test_iter_raises(self, klass):
714733
with pytest.raises(NotImplementedError):
715734
iter(obj.expanding(2))
716735

736+
def test_expanding_axis(self, axis_frame):
737+
# see gh-23372.
738+
df = DataFrame(np.ones((10, 20)))
739+
axis = df._get_axis_number(axis_frame)
740+
741+
if axis == 0:
742+
expected = DataFrame({
743+
i: [np.nan] * 2 + [float(j) for j in range(3, 11)]
744+
for i in range(20)
745+
})
746+
else:
747+
# axis == 1
748+
expected = DataFrame([
749+
[np.nan] * 2 + [float(i) for i in range(3, 21)]
750+
] * 10)
751+
752+
result = df.expanding(3, axis=axis_frame).sum()
753+
tm.assert_frame_equal(result, expected)
754+
717755

718756
class TestEWM(Base):
719757

0 commit comments

Comments
 (0)