Skip to content

Commit 6a40f3b

Browse files
ENH: .shift optionally takes multiple periods (#54115)
1 parent 075c6c0 commit 6a40f3b

File tree

7 files changed

+279
-24
lines changed

7 files changed

+279
-24
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ Other enhancements
169169
- Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`)
170170
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
171171
- Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`).
172+
- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`)
172173
- Groupby aggregations (such as :meth:`DataFrameGroupby.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
173174
- Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`)
174175
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)

pandas/core/frame.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -5578,20 +5578,48 @@ def _replace_columnwise(
55785578
@doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"])
55795579
def shift(
55805580
self,
5581-
periods: int = 1,
5581+
periods: int | Sequence[int] = 1,
55825582
freq: Frequency | None = None,
55835583
axis: Axis = 0,
55845584
fill_value: Hashable = lib.no_default,
5585+
suffix: str | None = None,
55855586
) -> DataFrame:
5586-
axis = self._get_axis_number(axis)
5587-
55885587
if freq is not None and fill_value is not lib.no_default:
55895588
# GH#53832
55905589
raise ValueError(
55915590
"Cannot pass both 'freq' and 'fill_value' to "
55925591
f"{type(self).__name__}.shift"
55935592
)
55945593

5594+
axis = self._get_axis_number(axis)
5595+
5596+
if is_list_like(periods):
5597+
periods = cast(Sequence, periods)
5598+
if axis == 1:
5599+
raise ValueError(
5600+
"If `periods` contains multiple shifts, `axis` cannot be 1."
5601+
)
5602+
if len(periods) == 0:
5603+
raise ValueError("If `periods` is an iterable, it cannot be empty.")
5604+
from pandas.core.reshape.concat import concat
5605+
5606+
shifted_dataframes = []
5607+
for period in periods:
5608+
if not is_integer(period):
5609+
raise TypeError(
5610+
f"Periods must be integer, but {period} is {type(period)}."
5611+
)
5612+
period = cast(int, period)
5613+
shifted_dataframes.append(
5614+
super()
5615+
.shift(periods=period, freq=freq, axis=axis, fill_value=fill_value)
5616+
.add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
5617+
)
5618+
return concat(shifted_dataframes, axis=1)
5619+
elif suffix:
5620+
raise ValueError("Cannot specify `suffix` if `periods` is an int.")
5621+
periods = cast(int, periods)
5622+
55955623
ncols = len(self.columns)
55965624
arrays = self._mgr.arrays
55975625
if axis == 1 and periods != 0 and ncols > 0 and freq is None:

pandas/core/generic.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -10545,11 +10545,12 @@ def mask(
1054510545
@doc(klass=_shared_doc_kwargs["klass"])
1054610546
def shift(
1054710547
self,
10548-
periods: int = 1,
10548+
periods: int | Sequence[int] = 1,
1054910549
freq=None,
1055010550
axis: Axis = 0,
1055110551
fill_value: Hashable = lib.no_default,
10552-
) -> Self:
10552+
suffix: str | None = None,
10553+
) -> Self | DataFrame:
1055310554
"""
1055410555
Shift index by desired number of periods with an optional time `freq`.
1055510556
@@ -10562,8 +10563,13 @@ def shift(
1056210563
1056310564
Parameters
1056410565
----------
10565-
periods : int
10566+
periods : int or Sequence[int]
1056610567
Number of periods to shift. Can be positive or negative.
10568+
If an iterable of ints, the data will be shifted once by each int.
10569+
This is equivalent to shifting by one value at a time and
10570+
concatenating all resulting frames. The resulting columns will have
10571+
the shift suffixed to their column names. For multiple periods,
10572+
axis must not be 1.
1056710573
freq : DateOffset, tseries.offsets, timedelta, or str, optional
1056810574
Offset to use from the tseries module or time rule (e.g. 'EOM').
1056910575
If `freq` is specified then the index values are shifted but the
@@ -10580,6 +10586,9 @@ def shift(
1058010586
For numeric data, ``np.nan`` is used.
1058110587
For datetime, timedelta, or period data, etc. :attr:`NaT` is used.
1058210588
For extension dtypes, ``self.dtype.na_value`` is used.
10589+
suffix : str, optional
10590+
If str and periods is an iterable, this is added after the column
10591+
name and before the shift value for each shifted column name.
1058310592
1058410593
Returns
1058510594
-------
@@ -10645,6 +10654,14 @@ def shift(
1064510654
2020-01-06 15 18 22
1064610655
2020-01-07 30 33 37
1064710656
2020-01-08 45 48 52
10657+
10658+
>>> df['Col1'].shift(periods=[0, 1, 2])
10659+
Col1_0 Col1_1 Col1_2
10660+
2020-01-01 10 NaN NaN
10661+
2020-01-02 20 10.0 NaN
10662+
2020-01-03 15 20.0 10.0
10663+
2020-01-04 30 15.0 20.0
10664+
2020-01-05 45 30.0 15.0
1064810665
"""
1064910666
axis = self._get_axis_number(axis)
1065010667

@@ -10658,6 +10675,12 @@ def shift(
1065810675
if periods == 0:
1065910676
return self.copy(deep=None)
1066010677

10678+
if is_list_like(periods) and isinstance(self, ABCSeries):
10679+
return self.to_frame().shift(
10680+
periods=periods, freq=freq, axis=axis, fill_value=fill_value
10681+
)
10682+
periods = cast(int, periods)
10683+
1066110684
if freq is None:
1066210685
# when freq is None, data is shifted, index is not
1066310686
axis = self._get_axis_number(axis)

pandas/core/groupby/groupby.py

+68-17
Original file line numberDiff line numberDiff line change
@@ -4932,10 +4932,11 @@ def cummax(
49324932
@Substitution(name="groupby")
49334933
def shift(
49344934
self,
4935-
periods: int = 1,
4935+
periods: int | Sequence[int] = 1,
49364936
freq=None,
49374937
axis: Axis | lib.NoDefault = lib.no_default,
49384938
fill_value=lib.no_default,
4939+
suffix: str | None = None,
49394940
):
49404941
"""
49414942
Shift each group by periods observations.
@@ -4944,8 +4945,9 @@ def shift(
49444945
49454946
Parameters
49464947
----------
4947-
periods : int, default 1
4948-
Number of periods to shift.
4948+
periods : int | Sequence[int], default 1
4949+
Number of periods to shift. If a list of values, shift each group by
4950+
each period.
49494951
freq : str, optional
49504952
Frequency string.
49514953
axis : axis to shift, default 0
@@ -4961,6 +4963,10 @@ def shift(
49614963
.. versionchanged:: 2.1.0
49624964
Will raise a ``ValueError`` if ``freq`` is provided too.
49634965
4966+
suffix : str, optional
4967+
A string to add to each shifted column if there are multiple periods.
4968+
Passing a non-empty ``suffix`` when ``periods`` is an int raises a ``ValueError``.
4969+
49644970
Returns
49654971
-------
49664972
Series or DataFrame
@@ -5014,25 +5020,70 @@ def shift(
50145020
else:
50155021
axis = 0
50165022

5017-
if freq is not None or axis != 0:
5018-
f = lambda x: x.shift(periods, freq, axis, fill_value)
5019-
return self._python_apply_general(f, self._selected_obj, is_transform=True)
5023+
if is_list_like(periods):
5024+
if axis == 1:
5025+
raise ValueError(
5026+
"If `periods` contains multiple shifts, `axis` cannot be 1."
5027+
)
5028+
periods = cast(Sequence, periods)
5029+
if len(periods) == 0:
5030+
raise ValueError("If `periods` is an iterable, it cannot be empty.")
5031+
from pandas.core.reshape.concat import concat
50205032

5021-
if fill_value is lib.no_default:
5022-
fill_value = None
5023-
ids, _, ngroups = self.grouper.group_info
5024-
res_indexer = np.zeros(len(ids), dtype=np.int64)
5033+
add_suffix = True
5034+
else:
5035+
if not is_integer(periods):
5036+
raise TypeError(
5037+
f"Periods must be integer, but {periods} is {type(periods)}."
5038+
)
5039+
if suffix:
5040+
raise ValueError("Cannot specify `suffix` if `periods` is an int.")
5041+
periods = [cast(int, periods)]
5042+
add_suffix = False
5043+
5044+
shifted_dataframes = []
5045+
for period in periods:
5046+
if not is_integer(period):
5047+
raise TypeError(
5048+
f"Periods must be integer, but {period} is {type(period)}."
5049+
)
5050+
period = cast(int, period)
5051+
if freq is not None or axis != 0:
5052+
f = lambda x: x.shift(
5053+
period, freq, axis, fill_value # pylint: disable=cell-var-from-loop
5054+
)
5055+
shifted = self._python_apply_general(
5056+
f, self._selected_obj, is_transform=True
5057+
)
5058+
else:
5059+
if fill_value is lib.no_default:
5060+
fill_value = None
5061+
ids, _, ngroups = self.grouper.group_info
5062+
res_indexer = np.zeros(len(ids), dtype=np.int64)
50255063

5026-
libgroupby.group_shift_indexer(res_indexer, ids, ngroups, periods)
5064+
libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period)
50275065

5028-
obj = self._obj_with_exclusions
5066+
obj = self._obj_with_exclusions
50295067

5030-
res = obj._reindex_with_indexers(
5031-
{self.axis: (obj.axes[self.axis], res_indexer)},
5032-
fill_value=fill_value,
5033-
allow_dups=True,
5068+
shifted = obj._reindex_with_indexers(
5069+
{self.axis: (obj.axes[self.axis], res_indexer)},
5070+
fill_value=fill_value,
5071+
allow_dups=True,
5072+
)
5073+
5074+
if add_suffix:
5075+
if isinstance(shifted, Series):
5076+
shifted = cast(NDFrameT, shifted.to_frame())
5077+
shifted = shifted.add_suffix(
5078+
f"{suffix}_{period}" if suffix else f"_{period}"
5079+
)
5080+
shifted_dataframes.append(cast(Union[Series, DataFrame], shifted))
5081+
5082+
return (
5083+
shifted_dataframes[0]
5084+
if len(shifted_dataframes) == 1
5085+
else concat(shifted_dataframes, axis=1)
50345086
)
5035-
return res
50365087

50375088
@final
50385089
@Substitution(name="groupby")

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3030,7 +3030,7 @@ def autocorr(self, lag: int = 1) -> float:
30303030
>>> s.autocorr()
30313031
nan
30323032
"""
3033-
return self.corr(self.shift(lag))
3033+
return self.corr(cast(Series, self.shift(lag)))
30343034

30353035
def dot(self, other: AnyArrayLike) -> Series | np.ndarray:
30363036
"""

pandas/tests/frame/methods/test_shift.py

+78
Original file line numberDiff line numberDiff line change
@@ -665,3 +665,81 @@ def test_shift_with_offsets_freq(self):
665665
index=date_range(start="02/01/2000", end="02/01/2000", periods=3),
666666
)
667667
tm.assert_frame_equal(shifted, expected)
668+
669+
def test_shift_with_iterable_basic_functionality(self):
670+
# GH#44424
671+
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
672+
shifts = [0, 1, 2]
673+
674+
df = DataFrame(data)
675+
shifted = df.shift(shifts)
676+
677+
expected = DataFrame(
678+
{
679+
"a_0": [1, 2, 3],
680+
"b_0": [4, 5, 6],
681+
"a_1": [np.NaN, 1.0, 2.0],
682+
"b_1": [np.NaN, 4.0, 5.0],
683+
"a_2": [np.NaN, np.NaN, 1.0],
684+
"b_2": [np.NaN, np.NaN, 4.0],
685+
}
686+
)
687+
tm.assert_frame_equal(expected, shifted)
688+
689+
def test_shift_with_iterable_series(self):
690+
# GH#44424
691+
data = {"a": [1, 2, 3]}
692+
shifts = [0, 1, 2]
693+
694+
df = DataFrame(data)
695+
s: Series = df["a"]
696+
tm.assert_frame_equal(s.shift(shifts), df.shift(shifts))
697+
698+
def test_shift_with_iterable_freq_and_fill_value(self):
699+
# GH#44424
700+
df = DataFrame(
701+
np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H")
702+
)
703+
704+
tm.assert_frame_equal(
705+
# rename because shift with an iterable leads to str column names
706+
df.shift([1], fill_value=1).rename(columns=lambda x: int(x[0])),
707+
df.shift(1, fill_value=1),
708+
)
709+
710+
tm.assert_frame_equal(
711+
df.shift([1], freq="H").rename(columns=lambda x: int(x[0])),
712+
df.shift(1, freq="H"),
713+
)
714+
715+
msg = r"Cannot pass both 'freq' and 'fill_value' to.*"
716+
with pytest.raises(ValueError, match=msg):
717+
df.shift([1, 2], fill_value=1, freq="H")
718+
719+
def test_shift_with_iterable_check_other_arguments(self):
720+
# GH#44424
721+
data = {"a": [1, 2], "b": [4, 5]}
722+
shifts = [0, 1]
723+
df = DataFrame(data)
724+
725+
# test suffix
726+
shifted = df[["a"]].shift(shifts, suffix="_suffix")
727+
expected = DataFrame({"a_suffix_0": [1, 2], "a_suffix_1": [np.nan, 1.0]})
728+
tm.assert_frame_equal(shifted, expected)
729+
730+
# check bad inputs when doing multiple shifts
731+
msg = "If `periods` contains multiple shifts, `axis` cannot be 1."
732+
with pytest.raises(ValueError, match=msg):
733+
df.shift(shifts, axis=1)
734+
735+
msg = "Periods must be integer, but s is <class 'str'>."
736+
with pytest.raises(TypeError, match=msg):
737+
df.shift(["s"])
738+
739+
msg = "If `periods` is an iterable, it cannot be empty."
740+
with pytest.raises(ValueError, match=msg):
741+
df.shift([])
742+
743+
msg = "Cannot specify `suffix` if `periods` is an int."
744+
with pytest.raises(ValueError, match=msg):
745+
df.shift(1, suffix="fails")

0 commit comments

Comments
 (0)