Skip to content

Commit 55b969b

Browse files
authored
BUG: rolling with datetime ArrowDtype (#56370)
* BUG: rolling with datetime ArrowDtype * Don't modify needs_i8_conversion * More explicit tests * Fix arrow to_numpy
1 parent 58ba000 commit 55b969b

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,7 @@ Groupby/resample/rolling
865865
- Bug in :meth:`DataFrame.resample` when resampling on an :class:`ArrowDtype` of ``pyarrow.timestamp`` or ``pyarrow.duration`` type (:issue:`55989`)
866866
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55281`)
867867
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.MonthBegin` (:issue:`55271`)
868+
- Bug in :meth:`DataFrame.rolling` and :meth:`Series.rolling` with a time-based window (e.g. ``"1D"``) when either the index or the ``on`` column had an :class:`ArrowDtype` with a ``pyarrow.timestamp`` type (:issue:`55849`)
868869

869870
Reshaping
870871
^^^^^^^^^

pandas/core/arrays/datetimelike.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
pandas_dtype,
9393
)
9494
from pandas.core.dtypes.dtypes import (
95+
ArrowDtype,
9596
CategoricalDtype,
9697
DatetimeTZDtype,
9798
ExtensionDtype,
@@ -2531,7 +2532,7 @@ def _validate_inferred_freq(
25312532
return freq
25322533

25332534

2534-
def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
2535+
def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype | ArrowDtype) -> str:
25352536
"""
25362537
Return the unit str corresponding to the dtype's resolution.
25372538
@@ -2546,4 +2547,8 @@ def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
25462547
"""
25472548
if isinstance(dtype, DatetimeTZDtype):
25482549
return dtype.unit
2550+
elif isinstance(dtype, ArrowDtype):
2551+
if dtype.kind not in "mM":
2552+
raise ValueError(f"{dtype=} does not have a resolution.")
2553+
return dtype.pyarrow_dtype.unit
25492554
return np.datetime_data(dtype)[0]

pandas/core/window/rolling.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Any,
1515
Callable,
1616
Literal,
17-
cast,
1817
)
1918

2019
import numpy as np
@@ -39,6 +38,7 @@
3938
is_numeric_dtype,
4039
needs_i8_conversion,
4140
)
41+
from pandas.core.dtypes.dtypes import ArrowDtype
4242
from pandas.core.dtypes.generic import (
4343
ABCDataFrame,
4444
ABCSeries,
@@ -104,6 +104,7 @@
104104
NDFrameT,
105105
QuantileInterpolation,
106106
WindowingRankType,
107+
npt,
107108
)
108109

109110
from pandas import (
@@ -404,11 +405,12 @@ def _insert_on_column(self, result: DataFrame, obj: DataFrame) -> None:
404405
result[name] = extra_col
405406

406407
@property
407-
def _index_array(self):
408+
def _index_array(self) -> npt.NDArray[np.int64] | None:
408409
# TODO: why do we get here with e.g. MultiIndex?
409-
if needs_i8_conversion(self._on.dtype):
410-
idx = cast("PeriodIndex | DatetimeIndex | TimedeltaIndex", self._on)
411-
return idx.asi8
410+
if isinstance(self._on, (PeriodIndex, DatetimeIndex, TimedeltaIndex)):
411+
return self._on.asi8
412+
elif isinstance(self._on.dtype, ArrowDtype) and self._on.dtype.kind in "mM":
413+
return self._on.to_numpy(dtype=np.int64)
412414
return None
413415

414416
def _resolve_output(self, out: DataFrame, obj: DataFrame) -> DataFrame:
@@ -439,7 +441,7 @@ def _apply_series(
439441
self, homogeneous_func: Callable[..., ArrayLike], name: str | None = None
440442
) -> Series:
441443
"""
442-
Series version of _apply_blockwise
444+
Series version of _apply_columnwise
443445
"""
444446
obj = self._create_data(self._selected_obj)
445447

@@ -455,7 +457,7 @@ def _apply_series(
455457
index = self._slice_axis_for_step(obj.index, result)
456458
return obj._constructor(result, index=index, name=obj.name)
457459

458-
def _apply_blockwise(
460+
def _apply_columnwise(
459461
self,
460462
homogeneous_func: Callable[..., ArrayLike],
461463
name: str,
@@ -614,7 +616,7 @@ def calc(x):
614616
return result
615617

616618
if self.method == "single":
617-
return self._apply_blockwise(homogeneous_func, name, numeric_only)
619+
return self._apply_columnwise(homogeneous_func, name, numeric_only)
618620
else:
619621
return self._apply_tablewise(homogeneous_func, name, numeric_only)
620622

@@ -1232,7 +1234,9 @@ def calc(x):
12321234

12331235
return result
12341236

1235-
return self._apply_blockwise(homogeneous_func, name, numeric_only)[:: self.step]
1237+
return self._apply_columnwise(homogeneous_func, name, numeric_only)[
1238+
:: self.step
1239+
]
12361240

12371241
@doc(
12381242
_shared_docs["aggregate"],
@@ -1868,6 +1872,7 @@ def _validate(self) -> None:
18681872
if (
18691873
self.obj.empty
18701874
or isinstance(self._on, (DatetimeIndex, TimedeltaIndex, PeriodIndex))
1875+
or (isinstance(self._on.dtype, ArrowDtype) and self._on.dtype.kind in "mM")
18711876
) and isinstance(self.window, (str, BaseOffset, timedelta)):
18721877
self._validate_datetimelike_monotonic()
18731878

pandas/tests/window/test_timeseries_window.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import numpy as np
22
import pytest
33

4+
import pandas.util._test_decorators as td
5+
46
from pandas import (
57
DataFrame,
68
DatetimeIndex,
9+
Index,
710
MultiIndex,
811
NaT,
912
Series,
@@ -697,3 +700,16 @@ def test_nat_axis_error(msg, axis):
697700
with pytest.raises(ValueError, match=f"{msg} values must not have NaT"):
698701
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
699702
df.rolling("D", axis=axis).mean()
703+
704+
705+
@td.skip_if_no("pyarrow")
706+
def test_arrow_datetime_axis():
707+
# GH 55849
708+
expected = Series(
709+
np.arange(5, dtype=np.float64),
710+
index=Index(
711+
date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]"
712+
),
713+
)
714+
result = expected.rolling("1D").sum()
715+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)