Skip to content

Commit 422ebc3

Browse files
johnzangwillgasparitiago
authored andcommitted
Move group_selection_context from module to class (pandas-dev#43796)
1 parent d6d8e3b commit 422ebc3

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

pandas/core/groupby/generic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
_agg_template,
7676
_apply_docs,
7777
_transform_template,
78-
group_selection_context,
7978
warn_dropping_nuisance_columns_deprecated,
8079
)
8180
from pandas.core.indexes.api import (
@@ -243,7 +242,7 @@ def apply(self, func, *args, **kwargs):
243242
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
244243

245244
if maybe_use_numba(engine):
246-
with group_selection_context(self):
245+
with self._group_selection_context():
247246
data = self._selected_obj
248247
result = self._aggregate_with_numba(
249248
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
@@ -845,7 +844,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
845844
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
846845

847846
if maybe_use_numba(engine):
848-
with group_selection_context(self):
847+
with self._group_selection_context():
849848
data = self._selected_obj
850849
result = self._aggregate_with_numba(
851850
data, func, *args, engine_kwargs=engine_kwargs, **kwargs

pandas/core/groupby/groupby.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -546,18 +546,6 @@ def f(self):
546546
return attr
547547

548548

549-
@contextmanager
550-
def group_selection_context(groupby: GroupBy) -> Iterator[GroupBy]:
551-
"""
552-
Set / reset the group_selection_context.
553-
"""
554-
groupby._set_group_selection()
555-
try:
556-
yield groupby
557-
finally:
558-
groupby._reset_group_selection()
559-
560-
561549
_KeysArgType = Union[
562550
Hashable,
563551
List[Hashable],
@@ -915,7 +903,7 @@ def __getattr__(self, attr: str):
915903
def _make_wrapper(self, name: str) -> Callable:
916904
assert name in self._apply_allowlist
917905

918-
with group_selection_context(self):
906+
with self._group_selection_context():
919907
# need to setup the selection
920908
# as are not passed directly but in the grouper
921909
f = getattr(self._obj_with_exclusions, name)
@@ -992,6 +980,17 @@ def _reset_group_selection(self) -> None:
992980
self._group_selection = None
993981
self._reset_cache("_selected_obj")
994982

983+
@contextmanager
984+
def _group_selection_context(self) -> Iterator[GroupBy]:
985+
"""
986+
Set / reset the _group_selection_context.
987+
"""
988+
self._set_group_selection()
989+
try:
990+
yield self
991+
finally:
992+
self._reset_group_selection()
993+
995994
def _iterate_slices(self) -> Iterable[Series]:
996995
raise AbstractMethodError(self)
997996

@@ -1365,7 +1364,7 @@ def f(g):
13651364
# fails on *some* columns, e.g. a numeric operation
13661365
# on a string grouper column
13671366

1368-
with group_selection_context(self):
1367+
with self._group_selection_context():
13691368
return self._python_apply_general(f, self._selected_obj)
13701369

13711370
return result
@@ -1445,7 +1444,7 @@ def _agg_general(
14451444
npfunc: Callable,
14461445
):
14471446

1448-
with group_selection_context(self):
1447+
with self._group_selection_context():
14491448
# try a cython aggregation if we can
14501449
result = self._cython_agg_general(
14511450
how=alias,
@@ -1552,7 +1551,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
15521551

15531552
if maybe_use_numba(engine):
15541553
# TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
1555-
with group_selection_context(self):
1554+
with self._group_selection_context():
15561555
data = self._selected_obj
15571556
df = data if data.ndim == 2 else data.to_frame()
15581557
result = self._transform_with_numba(
@@ -1954,7 +1953,7 @@ def var(self, ddof: int = 1):
19541953
)
19551954
else:
19561955
func = lambda x: x.var(ddof=ddof)
1957-
with group_selection_context(self):
1956+
with self._group_selection_context():
19581957
return self._python_agg_general(func)
19591958

19601959
@final
@@ -2146,7 +2145,7 @@ def ohlc(self) -> DataFrame:
21462145

21472146
@doc(DataFrame.describe)
21482147
def describe(self, **kwargs):
2149-
with group_selection_context(self):
2148+
with self._group_selection_context():
21502149
result = self.apply(lambda x: x.describe(**kwargs))
21512150
if self.axis == 1:
21522151
return result.T
@@ -2530,7 +2529,7 @@ def nth(
25302529
nth_values = list(set(n))
25312530

25322531
nth_array = np.array(nth_values, dtype=np.intp)
2533-
with group_selection_context(self):
2532+
with self._group_selection_context():
25342533

25352534
mask_left = np.in1d(self._cumcount_array(), nth_array)
25362535
mask_right = np.in1d(
@@ -2827,7 +2826,7 @@ def ngroup(self, ascending: bool = True):
28272826
5 0
28282827
dtype: int64
28292828
"""
2830-
with group_selection_context(self):
2829+
with self._group_selection_context():
28312830
index = self._selected_obj.index
28322831
result = self._obj_1d_constructor(
28332832
self.grouper.group_info[0], index, dtype=np.int64
@@ -2891,7 +2890,7 @@ def cumcount(self, ascending: bool = True):
28912890
5 0
28922891
dtype: int64
28932892
"""
2894-
with group_selection_context(self):
2893+
with self._group_selection_context():
28952894
index = self._selected_obj._get_axis(self.axis)
28962895
cumcounts = self._cumcount_array(ascending=ascending)
28972896
return self._obj_1d_constructor(cumcounts, index)

0 commit comments

Comments
 (0)