Skip to content

ENH: .shift optionally takes multiple periods #54115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4e1a84e
init
jona-sassenhagen Jul 13, 2023
db9cd03
precommit
jona-sassenhagen Jul 13, 2023
9def045
slightly update test
jona-sassenhagen Jul 13, 2023
b7ea297
Fix groupby API tests
jona-sassenhagen Jul 13, 2023
299927a
mostly types, also exclude groupby
jona-sassenhagen Jul 14, 2023
6cf632e
fix test
jona-sassenhagen Jul 14, 2023
23d9142
mypy
jona-sassenhagen Jul 14, 2023
b78d7d6
fix docstring
jona-sassenhagen Jul 14, 2023
cabd28c
change how futurewarning is handled in the test
jona-sassenhagen Jul 15, 2023
5017721
fix docstring
jona-sassenhagen Jul 15, 2023
6f3ec9b
remove debug statement
jona-sassenhagen Jul 15, 2023
8d085f3
address comments
jona-sassenhagen Jul 16, 2023
6e66a19
refactor
jona-sassenhagen Jul 16, 2023
f7ea7a3
handle default
jona-sassenhagen Jul 17, 2023
f8e29d9
pylint
jona-sassenhagen Jul 17, 2023
e54ba33
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 17, 2023
715c397
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 18, 2023
21e8e70
merge conflicts and mypy
jona-sassenhagen Jul 18, 2023
0e78963
split tests, remove checking for None default
jona-sassenhagen Jul 18, 2023
2c602e2
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 18, 2023
d9bf54f
address comments
jona-sassenhagen Jul 19, 2023
778b7e9
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 19, 2023
ad49861
Merge branch 'jona/44424_shift' of github.com:jona-sassenhagen/pandas…
jona-sassenhagen Jul 19, 2023
28469db
mypy
jona-sassenhagen Jul 19, 2023
c8e5bad
mypy again
jona-sassenhagen Jul 19, 2023
cb4013d
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 19, 2023
226fa6f
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 21, 2023
cb49cac
black
jona-sassenhagen Jul 21, 2023
44b0866
address comments
jona-sassenhagen Jul 25, 2023
5257292
Merge branch 'main' into jona/44424_shift
jona-sassenhagen Jul 25, 2023
eff4ed2
mypy
jona-sassenhagen Jul 25, 2023
04d4513
Merge branch 'jona/44424_shift' of github.com:jona-sassenhagen/pandas…
jona-sassenhagen Jul 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ Other enhancements
- Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`)
- Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`)
- Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`).
- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`)
- Groupby aggregations (such as :meth:`DataFrameGroupby.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`)
- Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`)
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`)
Expand Down
34 changes: 31 additions & 3 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5550,20 +5550,48 @@ def _replace_columnwise(
@doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"])
def shift(
self,
periods: int = 1,
periods: int | Sequence[int] = 1,
freq: Frequency | None = None,
axis: Axis = 0,
fill_value: Hashable = lib.no_default,
suffix: str | None = None,
) -> DataFrame:
axis = self._get_axis_number(axis)

if freq is not None and fill_value is not lib.no_default:
# GH#53832
raise ValueError(
"Cannot pass both 'freq' and 'fill_value' to "
f"{type(self).__name__}.shift"
)

axis = self._get_axis_number(axis)

if is_list_like(periods):
# periods is not necessarily a list, but otherwise mypy complains.
periods = cast(Sequence, periods)
if axis == 1:
raise ValueError(
"If `periods` contains multiple shifts, `axis` cannot be 1."
)
if len(periods) == 0:
raise ValueError("If `periods` is an iterable, it cannot be empty.")
from pandas.core.reshape.concat import concat

shifted_dataframes = []
for period in periods:
if not isinstance(period, int):
raise TypeError(
f"Periods must be integer, but {period} is {type(period)}."
)
shifted_dataframes.append(
super()
.shift(periods=period, freq=freq, axis=axis, fill_value=fill_value)
.add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
)
return concat(shifted_dataframes, axis=1)
elif suffix:
raise ValueError("Cannot specify `suffix` if `periods` is an int.")
periods = cast(int, periods)

ncols = len(self.columns)
arrays = self._mgr.arrays
if axis == 1 and periods != 0 and ncols > 0 and freq is None:
Expand Down
29 changes: 26 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10545,11 +10545,12 @@ def mask(
@doc(klass=_shared_doc_kwargs["klass"])
def shift(
self,
periods: int = 1,
periods: int | Sequence[int] = 1,
freq=None,
axis: Axis = 0,
fill_value: Hashable = lib.no_default,
) -> Self:
suffix: str | None = None,
) -> Self | DataFrame:
"""
Shift index by desired number of periods with an optional time `freq`.

Expand All @@ -10562,8 +10563,13 @@ def shift(

Parameters
----------
periods : int
periods : int or Sequence
Number of periods to shift. Can be positive or negative.
If an iterable of ints, the data will be shifted once by each int.
This is equivalent to shifting by one value at a time and
concatenating all resulting frames. The resulting columns will have
the shift suffixed to their column names. For multiple periods,
axis must not be 1.
freq : DateOffset, tseries.offsets, timedelta, or str, optional
Offset to use from the tseries module or time rule (e.g. 'EOM').
If `freq` is specified then the index values are shifted but the
Expand All @@ -10580,6 +10586,9 @@ def shift(
For numeric data, ``np.nan`` is used.
For datetime, timedelta, or period data, etc. :attr:`NaT` is used.
For extension dtypes, ``self.dtype.na_value`` is used.
suffix : str, optional
If str and periods is an iterable, this is added after the column
name and before the shift value for each shifted column name.

Returns
-------
Expand Down Expand Up @@ -10645,6 +10654,14 @@ def shift(
2020-01-06 15 18 22
2020-01-07 30 33 37
2020-01-08 45 48 52

>>> df['Col1'].shift(periods=[0, 1, 2])
Col1_0 Col1_1 Col1_2
2020-01-01 10 NaN NaN
2020-01-02 20 10.0 NaN
2020-01-03 15 20.0 10.0
2020-01-04 30 15.0 20.0
2020-01-05 45 30.0 15.0
"""
axis = self._get_axis_number(axis)

Expand All @@ -10658,6 +10675,12 @@ def shift(
if periods == 0:
return self.copy(deep=None)

if is_list_like(periods) and isinstance(self, ABCSeries):
return self.to_frame().shift(
periods=periods, freq=freq, axis=axis, fill_value=fill_value
)
periods = cast(int, periods)

if freq is None:
# when freq is None, data is shifted, index is not
axis = self._get_axis_number(axis)
Expand Down
84 changes: 67 additions & 17 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4925,10 +4925,11 @@ def cummax(
@Substitution(name="groupby")
def shift(
self,
periods: int = 1,
periods: int | Sequence[int] = 1,
freq=None,
axis: Axis | lib.NoDefault = lib.no_default,
fill_value=lib.no_default,
suffix: str | None = None,
):
"""
Shift each group by periods observations.
Expand All @@ -4937,8 +4938,9 @@ def shift(

Parameters
----------
periods : int, default 1
Number of periods to shift.
periods : int | Sequence[int], default 1
Number of periods to shift. If a list of values, shift each group by
each period.
freq : str, optional
Frequency string.
axis : axis to shift, default 0
Expand All @@ -4954,6 +4956,10 @@ def shift(
.. versionchanged:: 2.1.0
Will raise a ``ValueError`` if ``freq`` is provided too.

suffix : str, optional
A string to add to each shifted column if there are multiple periods.
A ``ValueError`` is raised if it is passed while ``periods`` is a single int.

Returns
-------
Series or DataFrame
Expand Down Expand Up @@ -5007,25 +5013,69 @@ def shift(
else:
axis = 0

if freq is not None or axis != 0:
f = lambda x: x.shift(periods, freq, axis, fill_value)
return self._python_apply_general(f, self._selected_obj, is_transform=True)
if is_list_like(periods):
if axis == 1:
raise ValueError(
"If `periods` contains multiple shifts, `axis` cannot be 1."
)
periods = cast(Sequence, periods)
if len(periods) == 0:
raise ValueError("If `periods` is an iterable, it cannot be empty.")
from pandas.core.reshape.concat import concat

if fill_value is lib.no_default:
fill_value = None
ids, _, ngroups = self.grouper.group_info
res_indexer = np.zeros(len(ids), dtype=np.int64)
add_suffix = True
else:
if not isinstance(periods, int):
raise TypeError(
f"Periods must be integer, but {periods} is {type(periods)}."
)
if suffix:
raise ValueError("Cannot specify `suffix` if `periods` is an int.")
periods = [periods]
add_suffix = False

shifted_dataframes = []
for period in periods:
if not isinstance(period, int):
raise TypeError(
f"Periods must be integer, but {period} is {type(period)}."
)
if freq is not None or axis != 0:
f = lambda x: x.shift(
period, freq, axis, fill_value # pylint: disable=cell-var-from-loop
)
shifted = self._python_apply_general(
f, self._selected_obj, is_transform=True
)
else:
if fill_value is lib.no_default:
fill_value = None
ids, _, ngroups = self.grouper.group_info
res_indexer = np.zeros(len(ids), dtype=np.int64)

libgroupby.group_shift_indexer(res_indexer, ids, ngroups, periods)
libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period)

obj = self._obj_with_exclusions
obj = self._obj_with_exclusions

res = obj._reindex_with_indexers(
{self.axis: (obj.axes[self.axis], res_indexer)},
fill_value=fill_value,
allow_dups=True,
shifted = obj._reindex_with_indexers(
{self.axis: (obj.axes[self.axis], res_indexer)},
fill_value=fill_value,
allow_dups=True,
)

if add_suffix:
if isinstance(shifted, Series):
shifted = cast(NDFrameT, shifted.to_frame())
shifted = shifted.add_suffix(
f"{suffix}_{period}" if suffix else f"_{period}"
)
shifted_dataframes.append(cast(Union[Series, DataFrame], shifted))

return (
shifted_dataframes[0]
if len(shifted_dataframes) == 1
else concat(shifted_dataframes, axis=1)
)
return res

@final
@Substitution(name="groupby")
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3012,7 +3012,7 @@ def autocorr(self, lag: int = 1) -> float:
>>> s.autocorr()
nan
"""
return self.corr(self.shift(lag))
return self.corr(cast(Series, self.shift(lag)))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary because Series.shift can now return a DataFrame if there are multiple periods. Mypy complains at this line because it wants Series and not Series | DataFrame, but because there is only one lag, the result is always a Series, so we can just cast.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add type-overloads to handle this. Sometimes they can get gnarly (see e.g. concat), but I think this one should be relatively straightforward.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer not to if that's ok? There is no Series.shift to overload so I feel like it would be confusing to newcomers like me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you insist I will try!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise I think this is done?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rhshadrach, do you have time for another round of review, or to sign this off?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with opening an issue regarding this once it's merged for someone to follow up on.


def dot(self, other: AnyArrayLike) -> Series | np.ndarray:
"""
Expand Down
78 changes: 78 additions & 0 deletions pandas/tests/frame/methods/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,81 @@ def test_shift_with_offsets_freq(self):
index=date_range(start="02/01/2000", end="02/01/2000", periods=3),
)
tm.assert_frame_equal(shifted, expected)

def test_shift_with_iterable_basic_functionality(self):
# GH#44424
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
shifts = [0, 1, 2]

df = DataFrame(data)
shifted = df.shift(shifts)

expected = DataFrame(
{
"a_0": [1, 2, 3],
"b_0": [4, 5, 6],
"a_1": [np.NaN, 1.0, 2.0],
"b_1": [np.NaN, 4.0, 5.0],
"a_2": [np.NaN, np.NaN, 1.0],
"b_2": [np.NaN, np.NaN, 4.0],
}
)
tm.assert_frame_equal(expected, shifted)

def test_shift_with_iterable_series(self):
# GH#44424
data = {"a": [1, 2, 3]}
shifts = [0, 1, 2]

df = DataFrame(data)
s: Series = df["a"]
tm.assert_frame_equal(s.shift(shifts), df.shift(shifts))

def test_shift_with_iterable_freq_and_fill_value(self):
    """Check ``freq`` and ``fill_value`` handling when ``periods`` is a list.

    A one-element list must give the same data as the scalar call for either
    argument on its own; passing both together must raise ``ValueError``.
    """
    # GH#44424
    hourly_index = date_range("1/1/2000", periods=5, freq="H")
    df = DataFrame(np.random.randn(5), index=hourly_index)

    def restore_int_label(label):
        # shift with an iterable leads to str column names, so map the
        # label back to the original integer column before comparing
        return int(label[0])

    via_list = df.shift([1], fill_value=1).rename(columns=restore_int_label)
    tm.assert_frame_equal(via_list, df.shift(1, fill_value=1))

    via_list = df.shift([1], freq="H").rename(columns=restore_int_label)
    tm.assert_frame_equal(via_list, df.shift(1, freq="H"))

    with pytest.raises(
        ValueError, match=r"Cannot pass both 'freq' and 'fill_value' to.*"
    ):
        df.shift([1, 2], fill_value=1, freq="H")

def test_shift_with_iterable_check_other_arguments(self):
    """Check ``suffix`` naming and the input validation of list-based shifts."""
    # GH#44424
    df = DataFrame({"a": [1, 2], "b": [4, 5]})
    shifts = [0, 1]

    # test suffix: it goes between the column name and the period
    expected = DataFrame({"a_suffix_0": [1, 2], "a_suffix_1": [np.nan, 1.0]})
    tm.assert_frame_equal(df[["a"]].shift(shifts, suffix="_suffix"), expected)

    # check bad inputs when doing multiple shifts; each entry is
    # (expected exception, expected message pattern, offending call)
    bad_calls = [
        (
            ValueError,
            "If `periods` contains multiple shifts, `axis` cannot be 1.",
            lambda: df.shift(shifts, axis=1),
        ),
        (
            TypeError,
            "Periods must be integer, but s is <class 'str'>.",
            lambda: df.shift(["s"]),
        ),
        (
            ValueError,
            "If `periods` is an iterable, it cannot be empty.",
            lambda: df.shift([]),
        ),
        (
            ValueError,
            "Cannot specify `suffix` if `periods` is an int.",
            lambda: df.shift(1, suffix="fails"),
        ),
    ]
    for exc_type, msg, bad_call in bad_calls:
        with pytest.raises(exc_type, match=msg):
            bad_call()
Loading