Skip to content

Commit 9e1bd7c

Browse files
authored
TYP: exclusions in BaseGroupBy (#36559)
1 parent 3ee6242 commit 9e1bd7c

File tree

2 files changed: +14 −13 lines changed

pandas/core/groupby/groupby.py

+4 −3
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ class providing the base-class of operations.
2424
Mapping,
2525
Optional,
2626
Sequence,
27+
Set,
2728
Tuple,
2829
Type,
2930
TypeVar,
@@ -36,7 +37,7 @@ class providing the base-class of operations.
3637

3738
from pandas._libs import Timestamp, lib
3839
import pandas._libs.groupby as libgroupby
39-
from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar
40+
from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Label, Scalar
4041
from pandas.compat.numpy import function as nv
4142
from pandas.errors import AbstractMethodError
4243
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc
@@ -488,7 +489,7 @@ def __init__(
488489
axis: int = 0,
489490
level=None,
490491
grouper: Optional["ops.BaseGrouper"] = None,
491-
exclusions=None,
492+
exclusions: Optional[Set[Label]] = None,
492493
selection=None,
493494
as_index: bool = True,
494495
sort: bool = True,
@@ -537,7 +538,7 @@ def __init__(
537538
self.obj = obj
538539
self.axis = obj._get_axis_number(axis)
539540
self.grouper = grouper
540-
self.exclusions = set(exclusions) if exclusions else set()
541+
self.exclusions = exclusions or set()
541542

542543
def __len__(self) -> int:
543544
return len(self.groups)

pandas/core/groupby/grouper.py

+10 −10
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,12 @@
22
Provide user facing operators for doing the split part of the
33
split-apply-combine paradigm.
44
"""
5-
from typing import Dict, Hashable, List, Optional, Tuple
5+
from typing import Dict, Hashable, List, Optional, Set, Tuple
66
import warnings
77

88
import numpy as np
99

10-
from pandas._typing import FrameOrSeries
10+
from pandas._typing import FrameOrSeries, Label
1111
from pandas.errors import InvalidIndexError
1212
from pandas.util._decorators import cache_readonly
1313

@@ -614,7 +614,7 @@ def get_grouper(
614614
mutated: bool = False,
615615
validate: bool = True,
616616
dropna: bool = True,
617-
) -> Tuple["ops.BaseGrouper", List[Hashable], FrameOrSeries]:
617+
) -> Tuple["ops.BaseGrouper", Set[Label], FrameOrSeries]:
618618
"""
619619
Create and return a BaseGrouper, which is an internal
620620
mapping of how to create the grouper indexers.
@@ -690,13 +690,13 @@ def get_grouper(
690690
if isinstance(key, Grouper):
691691
binner, grouper, obj = key._get_grouper(obj, validate=False)
692692
if key.key is None:
693-
return grouper, [], obj
693+
return grouper, set(), obj
694694
else:
695-
return grouper, [key.key], obj
695+
return grouper, {key.key}, obj
696696

697697
# already have a BaseGrouper, just return it
698698
elif isinstance(key, ops.BaseGrouper):
699-
return key, [], obj
699+
return key, set(), obj
700700

701701
if not isinstance(key, list):
702702
keys = [key]
@@ -739,7 +739,7 @@ def get_grouper(
739739
levels = [level] * len(keys)
740740

741741
groupings: List[Grouping] = []
742-
exclusions: List[Hashable] = []
742+
exclusions: Set[Label] = set()
743743

744744
# if the actual grouper should be obj[key]
745745
def is_in_axis(key) -> bool:
@@ -769,21 +769,21 @@ def is_in_obj(gpr) -> bool:
769769

770770
if is_in_obj(gpr): # df.groupby(df['name'])
771771
in_axis, name = True, gpr.name
772-
exclusions.append(name)
772+
exclusions.add(name)
773773

774774
elif is_in_axis(gpr): # df.groupby('name')
775775
if gpr in obj:
776776
if validate:
777777
obj._check_label_or_level_ambiguity(gpr, axis=axis)
778778
in_axis, name, gpr = True, gpr, obj[gpr]
779-
exclusions.append(name)
779+
exclusions.add(name)
780780
elif obj._is_level_reference(gpr, axis=axis):
781781
in_axis, name, level, gpr = False, None, gpr, None
782782
else:
783783
raise KeyError(gpr)
784784
elif isinstance(gpr, Grouper) and gpr.key is not None:
785785
# Add key to exclusions
786-
exclusions.append(gpr.key)
786+
exclusions.add(gpr.key)
787787
in_axis, name = False, None
788788
else:
789789
in_axis, name = False, None

0 commit comments

Comments
 (0)