From 2c26a8867af383e49af6cc4ad4817a705913069f Mon Sep 17 00:00:00 2001 From: John Zangwill Date: Wed, 29 Sep 2021 09:34:20 +0100 Subject: [PATCH] Move group_selection_context from module to class @jreback request to bring forwards part of #42947 --- pandas/core/groupby/generic.py | 5 ++--- pandas/core/groupby/groupby.py | 41 +++++++++++++++++----------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c5633b6809676..b4a21ccd2150a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -75,7 +75,6 @@ _agg_template, _apply_docs, _transform_template, - group_selection_context, warn_dropping_nuisance_columns_deprecated, ) from pandas.core.indexes.api import ( @@ -243,7 +242,7 @@ def apply(self, func, *args, **kwargs): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - with group_selection_context(self): + with self._group_selection_context(): data = self._selected_obj result = self._aggregate_with_numba( data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs @@ -845,7 +844,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - with group_selection_context(self): + with self._group_selection_context(): data = self._selected_obj result = self._aggregate_with_numba( data, func, *args, engine_kwargs=engine_kwargs, **kwargs diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 89755ec5f7863..c9e60155d3a06 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -546,18 +546,6 @@ def f(self): return attr -@contextmanager -def group_selection_context(groupby: GroupBy) -> Iterator[GroupBy]: - """ - Set / reset the group_selection_context. - """ - groupby._set_group_selection() - try: - yield groupby - finally: - groupby._reset_group_selection() - - _KeysArgType = Union[ Hashable, List[Hashable], @@ -915,7 +903,7 @@ def __getattr__(self, attr: str): def _make_wrapper(self, name: str) -> Callable: assert name in self._apply_allowlist - with group_selection_context(self): + with self._group_selection_context(): # need to setup the selection # as are not passed directly but in the grouper f = getattr(self._obj_with_exclusions, name) @@ -992,6 +980,17 @@ def _reset_group_selection(self) -> None: self._group_selection = None self._reset_cache("_selected_obj") + @contextmanager + def _group_selection_context(self) -> Iterator[GroupBy]: + """ + Set / reset the _group_selection_context. + """ + self._set_group_selection() + try: + yield self + finally: + self._reset_group_selection() + def _iterate_slices(self) -> Iterable[Series]: raise AbstractMethodError(self) @@ -1365,7 +1364,7 @@ def f(g): # fails on *some* columns, e.g. a numeric operation # on a string grouper column - with group_selection_context(self): + with self._group_selection_context(): return self._python_apply_general(f, self._selected_obj) return result @@ -1445,7 +1444,7 @@ def _agg_general( npfunc: Callable, ): - with group_selection_context(self): + with self._group_selection_context(): # try a cython aggregation if we can result = self._cython_agg_general( how=alias, @@ -1552,7 +1551,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): # TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy - with group_selection_context(self): + with self._group_selection_context(): data = self._selected_obj df = data if data.ndim == 2 else data.to_frame() result = self._transform_with_numba( @@ -1954,7 +1953,7 @@ def var(self, ddof: int = 1): ) else: func = lambda x: x.var(ddof=ddof) - with group_selection_context(self): + with self._group_selection_context(): return self._python_agg_general(func) @final @@ -2146,7 +2145,7 @@ def ohlc(self) -> DataFrame: @doc(DataFrame.describe) def describe(self, **kwargs): - with group_selection_context(self): + with self._group_selection_context(): result = self.apply(lambda x: x.describe(**kwargs)) if self.axis == 1: return result.T @@ -2530,7 +2529,7 @@ def nth( nth_values = list(set(n)) nth_array = np.array(nth_values, dtype=np.intp) - with group_selection_context(self): + with self._group_selection_context(): mask_left = np.in1d(self._cumcount_array(), nth_array) mask_right = np.in1d( @@ -2827,7 +2826,7 @@ def ngroup(self, ascending: bool = True): 5 0 dtype: int64 """ - with group_selection_context(self): + with self._group_selection_context(): index = self._selected_obj.index result = self._obj_1d_constructor( self.grouper.group_info[0], index, dtype=np.int64 @@ -2891,7 +2890,7 @@ def cumcount(self, ascending: bool = True): 5 0 dtype: int64 """ - with group_selection_context(self): + with self._group_selection_context(): index = self._selected_obj._get_axis(self.axis) cumcounts = self._cumcount_array(ascending=ascending) return self._obj_1d_constructor(cumcounts, index)