diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 8cf7543ae075b..46ca987d7b95b 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -168,6 +168,7 @@ Other enhancements - Added :meth:`ExtensionArray.interpolate` used by :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` (:issue:`53659`) - Added ``engine_kwargs`` parameter to :meth:`DataFrame.to_excel` (:issue:`53220`) - Added a new parameter ``by_row`` to :meth:`Series.apply` and :meth:`DataFrame.apply`. When set to ``False`` the supplied callables will always operate on the whole Series or DataFrame (:issue:`53400`, :issue:`53601`). +- :meth:`DataFrame.shift` and :meth:`Series.shift` now allow shifting by multiple periods by supplying a list of periods (:issue:`44424`) - Groupby aggregations (such as :meth:`DataFrameGroupby.sum`) now can preserve the dtype of the input instead of casting to ``float64`` (:issue:`44952`) - Improved error message when :meth:`DataFrameGroupBy.agg` failed (:issue:`52930`) - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 75a177f0969ca..e700b63f35485 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -5578,13 +5578,12 @@ def _replace_columnwise( @doc(NDFrame.shift, klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int = 1, + periods: int | Sequence[int] = 1, freq: Frequency | None = None, axis: Axis = 0, fill_value: Hashable = lib.no_default, + suffix: str | None = None, ) -> DataFrame: - axis = self._get_axis_number(axis) - if freq is not None and fill_value is not lib.no_default: # GH#53832 raise ValueError( @@ -5592,6 +5591,35 @@ def shift( f"{type(self).__name__}.shift" ) + axis = self._get_axis_number(axis) + + if is_list_like(periods): + periods = cast(Sequence, periods) + if axis == 1: + raise ValueError( + "If `periods` contains multiple shifts, `axis` cannot be 1." + ) + if len(periods) == 0: + raise ValueError("If `periods` is an iterable, it cannot be empty.") + from pandas.core.reshape.concat import concat + + shifted_dataframes = [] + for period in periods: + if not is_integer(period): + raise TypeError( + f"Periods must be integer, but {period} is {type(period)}." + ) + period = cast(int, period) + shifted_dataframes.append( + super() + .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) + .add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") + ) + return concat(shifted_dataframes, axis=1) + elif suffix: + raise ValueError("Cannot specify `suffix` if `periods` is an int.") + periods = cast(int, periods) + ncols = len(self.columns) arrays = self._mgr.arrays if axis == 1 and periods != 0 and ncols > 0 and freq is None: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9d09665b15be9..6e85d79afdbf8 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10545,11 +10545,12 @@ def mask( @doc(klass=_shared_doc_kwargs["klass"]) def shift( self, - periods: int = 1, + periods: int | Sequence[int] = 1, freq=None, axis: Axis = 0, fill_value: Hashable = lib.no_default, - ) -> Self: + suffix: str | None = None, + ) -> Self | DataFrame: """ Shift index by desired number of periods with an optional time `freq`. @@ -10562,8 +10563,13 @@ def shift( Parameters ---------- - periods : int + periods : int or Sequence Number of periods to shift. Can be positive or negative. + If an iterable of ints, the data will be shifted once by each int. + This is equivalent to shifting by one value at a time and + concatenating all resulting frames. The resulting columns will have + the shift suffixed to their column names. For multiple periods, + axis must not be 1. freq : DateOffset, tseries.offsets, timedelta, or str, optional Offset to use from the tseries module or time rule (e.g. 'EOM'). If `freq` is specified then the index values are shifted but the @@ -10580,6 +10586,9 @@ def shift( For numeric data, ``np.nan`` is used. For datetime, timedelta, or period data, etc. :attr:`NaT` is used. For extension dtypes, ``self.dtype.na_value`` is used. + suffix : str, optional + If str and periods is an iterable, this is added after the column + name and before the shift value for each shifted column name. Returns ------- @@ -10645,6 +10654,14 @@ def shift( 2020-01-06 15 18 22 2020-01-07 30 33 37 2020-01-08 45 48 52 + + >>> df['Col1'].shift(periods=[0, 1, 2]) + Col1_0 Col1_1 Col1_2 + 2020-01-01 10 NaN NaN + 2020-01-02 20 10.0 NaN + 2020-01-03 15 20.0 10.0 + 2020-01-04 30 15.0 20.0 + 2020-01-05 45 30.0 15.0 """ axis = self._get_axis_number(axis) @@ -10658,6 +10675,12 @@ def shift( if periods == 0: return self.copy(deep=None) + if is_list_like(periods) and isinstance(self, ABCSeries): + return self.to_frame().shift( + periods=periods, freq=freq, axis=axis, fill_value=fill_value + ) + periods = cast(int, periods) + if freq is None: # when freq is None, data is shifted, index is not axis = self._get_axis_number(axis) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8d8302826462e..81fadf718f85e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4932,10 +4932,11 @@ def cummax( @Substitution(name="groupby") def shift( self, - periods: int = 1, + periods: int | Sequence[int] = 1, freq=None, axis: Axis | lib.NoDefault = lib.no_default, fill_value=lib.no_default, + suffix: str | None = None, ): """ Shift each group by periods observations. @@ -4944,8 +4945,9 @@ def shift( Parameters ---------- - periods : int, default 1 - Number of periods to shift. + periods : int | Sequence[int], default 1 + Number of periods to shift. If a list of values, shift each group by + each period. freq : str, optional Frequency string. axis : axis to shift, default 0 @@ -4961,6 +4963,10 @@ def shift( .. versionchanged:: 2.1.0 Will raise a ``ValueError`` if ``freq`` is provided too. + suffix : str, optional + A string to add to each shifted column if there are multiple periods. + Ignored otherwise. + Returns ------- Series or DataFrame @@ -5014,25 +5020,70 @@ def shift( else: axis = 0 - if freq is not None or axis != 0: - f = lambda x: x.shift(periods, freq, axis, fill_value) - return self._python_apply_general(f, self._selected_obj, is_transform=True) + if is_list_like(periods): + if axis == 1: + raise ValueError( + "If `periods` contains multiple shifts, `axis` cannot be 1." + ) + periods = cast(Sequence, periods) + if len(periods) == 0: + raise ValueError("If `periods` is an iterable, it cannot be empty.") + from pandas.core.reshape.concat import concat - if fill_value is lib.no_default: - fill_value = None - ids, _, ngroups = self.grouper.group_info - res_indexer = np.zeros(len(ids), dtype=np.int64) + add_suffix = True + else: + if not is_integer(periods): + raise TypeError( + f"Periods must be integer, but {periods} is {type(periods)}." + ) + if suffix: + raise ValueError("Cannot specify `suffix` if `periods` is an int.") + periods = [cast(int, periods)] + add_suffix = False + + shifted_dataframes = [] + for period in periods: + if not is_integer(period): + raise TypeError( + f"Periods must be integer, but {period} is {type(period)}." + ) + period = cast(int, period) + if freq is not None or axis != 0: + f = lambda x: x.shift( + period, freq, axis, fill_value # pylint: disable=cell-var-from-loop + ) + shifted = self._python_apply_general( + f, self._selected_obj, is_transform=True + ) + else: + if fill_value is lib.no_default: + fill_value = None + ids, _, ngroups = self.grouper.group_info + res_indexer = np.zeros(len(ids), dtype=np.int64) - libgroupby.group_shift_indexer(res_indexer, ids, ngroups, periods) + libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period) - obj = self._obj_with_exclusions + obj = self._obj_with_exclusions - res = obj._reindex_with_indexers( - {self.axis: (obj.axes[self.axis], res_indexer)}, - fill_value=fill_value, - allow_dups=True, + shifted = obj._reindex_with_indexers( + {self.axis: (obj.axes[self.axis], res_indexer)}, + fill_value=fill_value, + allow_dups=True, + ) + + if add_suffix: + if isinstance(shifted, Series): + shifted = cast(NDFrameT, shifted.to_frame()) + shifted = shifted.add_suffix( + f"{suffix}_{period}" if suffix else f"_{period}" + ) + shifted_dataframes.append(cast(Union[Series, DataFrame], shifted)) + + return ( + shifted_dataframes[0] + if len(shifted_dataframes) == 1 + else concat(shifted_dataframes, axis=1) ) - return res @final @Substitution(name="groupby") diff --git a/pandas/core/series.py b/pandas/core/series.py index 81eea98ab3f65..677d53c7b6459 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -3028,7 +3028,7 @@ def autocorr(self, lag: int = 1) -> float: >>> s.autocorr() nan """ - return self.corr(self.shift(lag)) + return self.corr(cast(Series, self.shift(lag))) def dot(self, other: AnyArrayLike) -> Series | np.ndarray: """ diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index c024cb0387ff2..d2281ad3a7480 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -665,3 +665,81 @@ def test_shift_with_offsets_freq(self): index=date_range(start="02/01/2000", end="02/01/2000", periods=3), ) tm.assert_frame_equal(shifted, expected) + + def test_shift_with_iterable_basic_functionality(self): + # GH#44424 + data = {"a": [1, 2, 3], "b": [4, 5, 6]} + shifts = [0, 1, 2] + + df = DataFrame(data) + shifted = df.shift(shifts) + + expected = DataFrame( + { + "a_0": [1, 2, 3], + "b_0": [4, 5, 6], + "a_1": [np.NaN, 1.0, 2.0], + "b_1": [np.NaN, 4.0, 5.0], + "a_2": [np.NaN, np.NaN, 1.0], + "b_2": [np.NaN, np.NaN, 4.0], + } + ) + tm.assert_frame_equal(expected, shifted) + + def test_shift_with_iterable_series(self): + # GH#44424 + data = {"a": [1, 2, 3]} + shifts = [0, 1, 2] + + df = DataFrame(data) + s: Series = df["a"] + tm.assert_frame_equal(s.shift(shifts), df.shift(shifts)) + + def test_shift_with_iterable_freq_and_fill_value(self): + # GH#44424 + df = DataFrame( + np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H") + ) + + tm.assert_frame_equal( + # rename because shift with an iterable leads to str column names + df.shift([1], fill_value=1).rename(columns=lambda x: int(x[0])), + df.shift(1, fill_value=1), + ) + + tm.assert_frame_equal( + df.shift([1], freq="H").rename(columns=lambda x: int(x[0])), + df.shift(1, freq="H"), + ) + + msg = r"Cannot pass both 'freq' and 'fill_value' to.*" + with pytest.raises(ValueError, match=msg): + df.shift([1, 2], fill_value=1, freq="H") + + def test_shift_with_iterable_check_other_arguments(self): + # GH#44424 + data = {"a": [1, 2], "b": [4, 5]} + shifts = [0, 1] + df = DataFrame(data) + + # test suffix + shifted = df[["a"]].shift(shifts, suffix="_suffix") + expected = DataFrame({"a_suffix_0": [1, 2], "a_suffix_1": [np.nan, 1.0]}) + tm.assert_frame_equal(shifted, expected) + + # check bad inputs when doing multiple shifts + msg = "If `periods` contains multiple shifts, `axis` cannot be 1." + with pytest.raises(ValueError, match=msg): + df.shift(shifts, axis=1) + + msg = "Periods must be integer, but s is ." + with pytest.raises(TypeError, match=msg): + df.shift(["s"]) + + msg = "If `periods` is an iterable, it cannot be empty." + with pytest.raises(ValueError, match=msg): + df.shift([]) + + msg = "Cannot specify `suffix` if `periods` is an int." + with pytest.raises(ValueError, match=msg): + df.shift(1, suffix="fails") diff --git a/pandas/tests/groupby/test_groupby_shift_diff.py b/pandas/tests/groupby/test_groupby_shift_diff.py index cec8ea9d351cf..495f3fcd359c7 100644 --- a/pandas/tests/groupby/test_groupby_shift_diff.py +++ b/pandas/tests/groupby/test_groupby_shift_diff.py @@ -173,3 +173,77 @@ def test_shift_disallow_freq_and_fill_value(): msg = "Cannot pass both 'freq' and 'fill_value' to (Series|DataFrame).shift" with pytest.raises(ValueError, match=msg): df.groupby(df.index).shift(periods=-2, freq="D", fill_value="1") + + +def test_shift_disallow_suffix_if_periods_is_int(): + # GH#44424 + data = {"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]} + df = DataFrame(data) + msg = "Cannot specify `suffix` if `periods` is an int." + with pytest.raises(ValueError, match=msg): + df.groupby("b").shift(1, suffix="fails") + + +def test_group_shift_with_multiple_periods(): + # GH#44424 + df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]}) + + shifted_df = df.groupby("b")[["a"]].shift([0, 1]) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 3, 2], "a_1": [np.nan, 1.0, np.nan, 3.0, 2.0]} + ) + tm.assert_frame_equal(shifted_df, expected_df) + + # series + shifted_series = df.groupby("b")["a"].shift([0, 1]) + tm.assert_frame_equal(shifted_series, expected_df) + + +def test_group_shift_with_multiple_periods_and_freq(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="H"), + ) + shifted_df = df.groupby("b")[["a"]].shift( + [0, 1], + freq="H", + ) + expected_df = DataFrame( + { + "a_0": [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], + "a_1": [ + np.nan, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + ], + }, + index=date_range("1/1/2000", periods=6, freq="H"), + ) + tm.assert_frame_equal(shifted_df, expected_df) + + +def test_group_shift_with_multiple_periods_and_fill_value(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + ) + shifted_df = df.groupby("b")[["a"]].shift([0, 1], fill_value=-1) + expected_df = DataFrame( + {"a_0": [1, 2, 3, 4, 5], "a_1": [-1, 1, -1, 3, 2]}, + ) + tm.assert_frame_equal(shifted_df, expected_df) + + +def test_group_shift_with_multiple_periods_and_both_fill_and_freq_fails(): + # GH#44424 + df = DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [True, True, False, False, True]}, + index=date_range("1/1/2000", periods=5, freq="H"), + ) + msg = r"Cannot pass both 'freq' and 'fill_value' to.*" + with pytest.raises(ValueError, match=msg): + df.groupby("b")[["a"]].shift([1, 2], fill_value=1, freq="H")