diff --git a/pandas/core/base.py b/pandas/core/base.py index d4a808f4d7dd1..3710a644c7826 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -77,6 +77,7 @@ ) from pandas import ( + DataFrame, Index, Series, ) @@ -254,6 +255,21 @@ def _gotitem(self, key, ndim: int, subset=None): """ raise AbstractMethodError(self) + @final + def _infer_selection(self, key, subset: Series | DataFrame): + """ + Infer the `selection` to pass to our constructor in _gotitem. + """ + # Shared by Rolling and Resample + selection = None + if subset.ndim == 2 and ( + (lib.is_scalar(key) and key in subset) or lib.is_list_like(key) + ): + selection = key + elif subset.ndim == 1 and lib.is_scalar(key) and key == subset.name: + selection = key + return selection + def aggregate(self, func, *args, **kwargs): raise AbstractMethodError(self) diff --git a/pandas/core/resample.py b/pandas/core/resample.py index 8d3ff10ba91b3..9566a2f113b36 100644 --- a/pandas/core/resample.py +++ b/pandas/core/resample.py @@ -42,7 +42,10 @@ import pandas.core.algorithms as algos from pandas.core.apply import ResamplerWindowApply -from pandas.core.base import PandasObject +from pandas.core.base import ( + PandasObject, + SelectionMixin, +) import pandas.core.common as com from pandas.core.generic import ( NDFrame, @@ -1293,7 +1296,7 @@ def quantile(self, q: float | AnyArrayLike = 0.5, **kwargs): return self._downsample("quantile", q=q, **kwargs) -class _GroupByMixin(PandasObject): +class _GroupByMixin(PandasObject, SelectionMixin): """ Provide the groupby facilities. """ @@ -1385,13 +1388,7 @@ def _gotitem(self, key, ndim, subset=None): except IndexError: groupby = self._groupby - selection = None - if subset.ndim == 2 and ( - (lib.is_scalar(key) and key in subset) or lib.is_list_like(key) - ): - selection = key - elif subset.ndim == 1 and lib.is_scalar(key) and key == subset.name: - selection = key + selection = self._infer_selection(key, subset) new_rs = type(self)( groupby=groupby, diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index a08ffcc9f7200..f4d733423b3ae 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -37,9 +37,7 @@ ensure_float64, is_bool, is_integer, - is_list_like, is_numeric_dtype, - is_scalar, needs_i8_conversion, ) from pandas.core.dtypes.generic import ( @@ -302,14 +300,7 @@ def _gotitem(self, key, ndim, subset=None): # with the same groupby kwargs = {attr: getattr(self, attr) for attr in self._attributes} - selection = None - if subset.ndim == 2 and ( - (is_scalar(key) and key in subset) or is_list_like(key) - ): - selection = key - elif subset.ndim == 1 and is_scalar(key) and key == subset.name: - selection = key - + selection = self._infer_selection(key, subset) new_win = type(self)(subset, selection=selection, **kwargs) return new_win