Skip to content

Commit 00a510b

Browse files
authored
[BUG]: Rolling.sum() calculated wrong values when axis is one and dtypes are mixed (#36458)
1 parent a22cf43 commit 00a510b

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ Groupby/resample/rolling
344344
- Bug in :meth:`DataFrameGroupby.tshift` failing to raise ``ValueError`` when a frequency cannot be inferred for the index of a group (:issue:`35937`)
345345
- Bug in :meth:`DataFrame.groupby` does not always maintain column index name for ``any``, ``all``, ``bfill``, ``ffill``, ``shift`` (:issue:`29764`)
346346
- Bug in :meth:`DataFrameGroupBy.apply` raising error with ``np.nan`` group(s) when ``dropna=False`` (:issue:`35889`)
347-
-
347+
- Bug in :meth:`Rolling.sum()` returned wrong values when dtypes where mixed between float and integer and axis was equal to one (:issue:`20649`, :issue:`35596`)
348348

349349
Reshaping
350350
^^^^^^^^^

pandas/core/window/rolling.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,13 @@ def _create_data(self, obj: FrameOrSeries) -> FrameOrSeries:
243243
if self.on is not None and not isinstance(self.on, Index):
244244
if obj.ndim == 2:
245245
obj = obj.reindex(columns=obj.columns.difference([self.on]), copy=False)
246-
246+
if self.axis == 1:
247+
# GH: 20649 in case of mixed dtype and axis=1 we have to convert everything
248+
# to float to calculate the complete row at once. We exclude all non-numeric
249+
# dtypes.
250+
obj = obj.select_dtypes(include=["integer", "float"], exclude=["timedelta"])
251+
obj = obj.astype("float64", copy=False)
252+
obj._mgr = obj._mgr.consolidate()
247253
return obj
248254

249255
def _gotitem(self, key, ndim, subset=None):

pandas/tests/window/test_rolling.py

+48
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,51 @@ def test_rolling_numerical_too_large_numbers():
771771
index=dates,
772772
)
773773
tm.assert_series_equal(result, expected)
774+
775+
776+
@pytest.mark.parametrize(
777+
("func", "value"),
778+
[("sum", 2.0), ("max", 1.0), ("min", 1.0), ("mean", 1.0), ("median", 1.0)],
779+
)
780+
def test_rolling_mixed_dtypes_axis_1(func, value):
781+
# GH: 20649
782+
df = pd.DataFrame(1, index=[1, 2], columns=["a", "b", "c"])
783+
df["c"] = 1.0
784+
result = getattr(df.rolling(window=2, min_periods=1, axis=1), func)()
785+
expected = pd.DataFrame(
786+
{"a": [1.0, 1.0], "b": [value, value], "c": [value, value]}, index=[1, 2]
787+
)
788+
tm.assert_frame_equal(result, expected)
789+
790+
791+
def test_rolling_axis_one_with_nan():
792+
# GH: 35596
793+
df = pd.DataFrame(
794+
[
795+
[0, 1, 2, 4, np.nan, np.nan, np.nan],
796+
[0, 1, 2, np.nan, np.nan, np.nan, np.nan],
797+
[0, 2, 2, np.nan, 2, np.nan, 1],
798+
]
799+
)
800+
result = df.rolling(window=7, min_periods=1, axis="columns").sum()
801+
expected = pd.DataFrame(
802+
[
803+
[0.0, 1.0, 3.0, 7.0, 7.0, 7.0, 7.0],
804+
[0.0, 1.0, 3.0, 3.0, 3.0, 3.0, 3.0],
805+
[0.0, 2.0, 4.0, 4.0, 6.0, 6.0, 7.0],
806+
]
807+
)
808+
tm.assert_frame_equal(result, expected)
809+
810+
811+
@pytest.mark.parametrize(
812+
"value",
813+
["test", pd.to_datetime("2019-12-31"), pd.to_timedelta("1 days 06:05:01.00003")],
814+
)
815+
def test_rolling_axis_1_non_numeric_dtypes(value):
816+
# GH: 20649
817+
df = pd.DataFrame({"a": [1, 2]})
818+
df["b"] = value
819+
result = df.rolling(window=2, min_periods=1, axis=1).sum()
820+
expected = pd.DataFrame({"a": [1.0, 2.0]})
821+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)