Skip to content

Move group_selection_context from module to class #43796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
_agg_template,
_apply_docs,
_transform_template,
group_selection_context,
warn_dropping_nuisance_columns_deprecated,
)
from pandas.core.indexes.api import (
Expand Down Expand Up @@ -243,7 +242,7 @@ def apply(self, func, *args, **kwargs):
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
with group_selection_context(self):
with self._group_selection_context():
data = self._selected_obj
result = self._aggregate_with_numba(
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
Expand Down Expand Up @@ -845,7 +844,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
with group_selection_context(self):
with self._group_selection_context():
data = self._selected_obj
result = self._aggregate_with_numba(
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
Expand Down
41 changes: 20 additions & 21 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,18 +546,6 @@ def f(self):
return attr


@contextmanager
def group_selection_context(groupby: GroupBy) -> Iterator[GroupBy]:
"""
Set / reset the group_selection_context.
"""
groupby._set_group_selection()
try:
yield groupby
finally:
groupby._reset_group_selection()


_KeysArgType = Union[
Hashable,
List[Hashable],
Expand Down Expand Up @@ -915,7 +903,7 @@ def __getattr__(self, attr: str):
def _make_wrapper(self, name: str) -> Callable:
assert name in self._apply_allowlist

with group_selection_context(self):
with self._group_selection_context():
# need to setup the selection
# as are not passed directly but in the grouper
f = getattr(self._obj_with_exclusions, name)
Expand Down Expand Up @@ -992,6 +980,17 @@ def _reset_group_selection(self) -> None:
self._group_selection = None
self._reset_cache("_selected_obj")

@contextmanager
def _group_selection_context(self) -> Iterator[GroupBy]:
"""
Set / reset the _group_selection_context.
"""
self._set_group_selection()
try:
yield self
finally:
self._reset_group_selection()

def _iterate_slices(self) -> Iterable[Series]:
raise AbstractMethodError(self)

Expand Down Expand Up @@ -1365,7 +1364,7 @@ def f(g):
# fails on *some* columns, e.g. a numeric operation
# on a string grouper column

with group_selection_context(self):
with self._group_selection_context():
return self._python_apply_general(f, self._selected_obj)

return result
Expand Down Expand Up @@ -1445,7 +1444,7 @@ def _agg_general(
npfunc: Callable,
):

with group_selection_context(self):
with self._group_selection_context():
# try a cython aggregation if we can
result = self._cython_agg_general(
how=alias,
Expand Down Expand Up @@ -1552,7 +1551,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
# TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
with group_selection_context(self):
with self._group_selection_context():
data = self._selected_obj
df = data if data.ndim == 2 else data.to_frame()
result = self._transform_with_numba(
Expand Down Expand Up @@ -1954,7 +1953,7 @@ def var(self, ddof: int = 1):
)
else:
func = lambda x: x.var(ddof=ddof)
with group_selection_context(self):
with self._group_selection_context():
return self._python_agg_general(func)

@final
Expand Down Expand Up @@ -2146,7 +2145,7 @@ def ohlc(self) -> DataFrame:

@doc(DataFrame.describe)
def describe(self, **kwargs):
with group_selection_context(self):
with self._group_selection_context():
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
Expand Down Expand Up @@ -2530,7 +2529,7 @@ def nth(
nth_values = list(set(n))

nth_array = np.array(nth_values, dtype=np.intp)
with group_selection_context(self):
with self._group_selection_context():

mask_left = np.in1d(self._cumcount_array(), nth_array)
mask_right = np.in1d(
Expand Down Expand Up @@ -2827,7 +2826,7 @@ def ngroup(self, ascending: bool = True):
5 0
dtype: int64
"""
with group_selection_context(self):
with self._group_selection_context():
index = self._selected_obj.index
result = self._obj_1d_constructor(
self.grouper.group_info[0], index, dtype=np.int64
Expand Down Expand Up @@ -2891,7 +2890,7 @@ def cumcount(self, ascending: bool = True):
5 0
dtype: int64
"""
with group_selection_context(self):
with self._group_selection_context():
index = self._selected_obj._get_axis(self.axis)
cumcounts = self._cumcount_array(ascending=ascending)
return self._obj_1d_constructor(cumcounts, index)
Expand Down