Skip to content

ENH: .shift optionally takes multiple periods #54115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4e1a84e
init
jona-sassenhagen Jul 13, 2023
db9cd03
precommit
jona-sassenhagen Jul 13, 2023
9def045
slightly update test
jona-sassenhagen Jul 13, 2023
b7ea297
Fix groupby API tests
jona-sassenhagen Jul 13, 2023
299927a
mostly types, also exclude groupby
jona-sassenhagen Jul 14, 2023
6cf632e
fix test
jona-sassenhagen Jul 14, 2023
23d9142
mypy
jona-sassenhagen Jul 14, 2023
b78d7d6
fix docstring
jona-sassenhagen Jul 14, 2023
cabd28c
change how futurewarning is handled in the test
jona-sassenhagen Jul 15, 2023
5017721
fix docstring
jona-sassenhagen Jul 15, 2023
6f3ec9b
remove debug statement
jona-sassenhagen Jul 15, 2023
8d085f3
address comments
jona-sassenhagen Jul 16, 2023
6e66a19
refactor
jona-sassenhagen Jul 16, 2023
f7ea7a3
handle default
jona-sassenhagen Jul 17, 2023
f8e29d9
pylint
jona-sassenhagen Jul 17, 2023
e54ba33
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 17, 2023
715c397
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 18, 2023
21e8e70
merge conflicts and mypy
jona-sassenhagen Jul 18, 2023
0e78963
split tests, remove checking for None default
jona-sassenhagen Jul 18, 2023
2c602e2
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 18, 2023
d9bf54f
address comments
jona-sassenhagen Jul 19, 2023
778b7e9
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 19, 2023
ad49861
Merge branch 'jona/44424_shift' of github.com:jona-sassenhagen/pandas…
jona-sassenhagen Jul 19, 2023
28469db
mypy
jona-sassenhagen Jul 19, 2023
c8e5bad
mypy again
jona-sassenhagen Jul 19, 2023
cb4013d
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 19, 2023
226fa6f
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 21, 2023
cb49cac
black
jona-sassenhagen Jul 21, 2023
44b0866
address comments
jona-sassenhagen Jul 25, 2023
5257292
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 25, 2023
eff4ed2
mypy
jona-sassenhagen Jul 25, 2023
04d4513
Merge branch 'jona/44424_shift' of github.com:jona-sassenhagen/pandas…
jona-sassenhagen Jul 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Other enhancements
- Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`)
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
- Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`).
- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`)
- Groupby aggregations (such as :meth:`DataFrameGroupBy.sum`) can now preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
- Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`)
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
Expand Down
25 changes: 24 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5501,13 +5501,36 @@ def _replace_columnwise(
@doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"])
def shift(
self,
periods: int = 1,
periods: int | Iterable[int] = 1,
freq: Frequency | None = None,
axis: Axis = 0,
fill_value: Hashable = lib.no_default,
suffix: str | None = None,
) -> DataFrame:
axis = self._get_axis_number(axis)

if is_list_like(periods):
if axis == 1:
raise ValueError(
"If `periods` contains multiple shifts, `axis` cannot be 1."
)
if len(periods) == 0:
raise ValueError("If `periods` is an iterable, it cannot be empty.")
from pandas.core.reshape.concat import concat

result = []
for period in periods:
if not isinstance(period, int):
raise TypeError(
f"Periods must be integer, but {period} is {type(period)}."
)
result.append(
super()
.shift(periods=period, freq=freq, axis=axis, fill_value=fill_value)
.add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
)
return concat(result, axis=1) if result else self

if freq is not None and fill_value is not lib.no_default:
# GH#53832
raise ValueError(
Expand Down
27 changes: 25 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
if TYPE_CHECKING:
from collections.abc import (
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
Expand Down Expand Up @@ -10521,10 +10522,11 @@ def mask(
@doc(klass=_shared_doc_kwargs["klass"])
def shift(
self,
periods: int = 1,
periods: int | Iterable = 1,
freq=None,
axis: Axis = 0,
fill_value: Hashable = lib.no_default,
suffix: str | None = None,
) -> Self:
"""
Shift index by desired number of periods with an optional time `freq`.
Expand All @@ -10538,8 +10540,13 @@ def shift(

Parameters
----------
periods : int
periods : int or Iterable
Number of periods to shift. Can be positive or negative.
If an iterable of ints, the data will be shifted once by each int.
This is equivalent to shifting by one value at a time and
concatenating all resulting frames. The resulting columns will have
the shift suffixed to their column names. For multiple periods,
axis must not be 1.
freq : DateOffset, tseries.offsets, timedelta, or str, optional
Offset to use from the tseries module or time rule (e.g. 'EOM').
If `freq` is specified then the index values are shifted but the
Expand All @@ -10556,6 +10563,9 @@ def shift(
For numeric data, ``np.nan`` is used.
For datetime, timedelta, or period data, etc. :attr:`NaT` is used.
For extension dtypes, ``self.dtype.na_value`` is used.
suffix : str, optional
If str and periods is an iterable, this is added after the column
name and before the shift value for each shifted column name.

Returns
-------
Expand Down Expand Up @@ -10621,6 +10631,14 @@ def shift(
2020-01-06 15 18 22
2020-01-07 30 33 37
2020-01-08 45 48 52

>>> df['Col1'].shift(periods=[0, 1, 2])
Col1_0 Col1_1 Col1_2
2020-01-01 10 NaN NaN
2020-01-02 20 10.0 NaN
2020-01-03 15 20.0 10.0
2020-01-04 30 15.0 20.0
2020-01-05 45 30.0 15.0
"""
axis = self._get_axis_number(axis)

Expand All @@ -10634,6 +10652,11 @@ def shift(
if periods == 0:
return self.copy(deep=None)

if is_list_like(periods) and len(self.shape) == 1:
return self.to_frame().shift(
periods=periods, freq=freq, axis=axis, fill_value=fill_value
)

if freq is None:
# when freq is None, data is shifted, index is not
axis = self._get_axis_number(axis)
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4866,6 +4866,7 @@ def shift(
freq=None,
axis: Axis | lib.NoDefault = lib.no_default,
fill_value=None,
suffix: str | None = None,
):
"""
Shift each group by periods observations.
Expand All @@ -4887,6 +4888,8 @@ def shift(

fill_value : optional
The scalar value to use for newly introduced missing values.
suffix : str, optional
An optional suffix to append when there are multiple periods.

Returns
-------
Expand Down
42 changes: 42 additions & 0 deletions pandas/tests/frame/methods/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,45 @@ def test_shift_axis1_many_periods(self):

shifted2 = df.shift(-6, axis=1, fill_value=None)
tm.assert_frame_equal(shifted2, expected)

def test_shift_with_iterable(self):
    """
    Check ``shift`` when ``periods`` is an iterable of ints (GH#44424).

    Covers four things:

    * a DataFrame shifted by ``[0, 1, 2]`` yields one group of columns per
      period, each column named ``<original>_<period>``;
    * a Series shifted by a list gives the same frame as the equivalent
      one-column DataFrame;
    * ``suffix`` is placed between the original column name and the period;
    * invalid inputs (``axis=1``, non-integer periods, empty iterable) raise
      the documented errors.
    """
    # GH#44424
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    shifts = [0, 1, 2]

    df = DataFrame(data)
    shifted = df.shift(shifts)

    # ``np.nan`` is used instead of the ``np.NaN`` alias, which NumPy 2.0
    # removed; the alias would raise AttributeError before any assertion runs.
    expected = DataFrame(
        {
            "a_0": [1, 2, 3],
            "b_0": [4, 5, 6],
            "a_1": [np.nan, 1.0, 2.0],
            "b_1": [np.nan, 4.0, 5.0],
            "a_2": [np.nan, np.nan, 1.0],
            "b_2": [np.nan, np.nan, 4.0],
        }
    )
    tm.assert_frame_equal(expected, shifted)

    # test pd.Series: must match the result for the one-column frame
    s: Series = df["a"]
    df_one_column: DataFrame = df[["a"]]
    tm.assert_frame_equal(s.shift(shifts), df_one_column.shift(shifts))

    # test suffix: inserted after the column name, before the period
    columns = df[["a"]].shift(shifts, suffix="_suffix").columns
    assert columns.tolist() == ["a_suffix_0", "a_suffix_1", "a_suffix_2"]

    # check bad inputs when doing multiple shifts
    msg = "If `periods` contains multiple shifts, `axis` cannot be 1."
    with pytest.raises(ValueError, match=msg):
        df.shift([1, 2], axis=1)

    msg = "Periods must be integer, but s is <class 'str'>."
    with pytest.raises(TypeError, match=msg):
        df.shift(["s"])

    msg = "If `periods` is an iterable, it cannot be empty."
    with pytest.raises(ValueError, match=msg):
        df.shift([])