Skip to content

Commit 9bd15db

Browse files
authored
REF: remove _get_grouper, make Grouper.__init__ less stateful (#51155)
* REF: inline BaseGrouper/BinGrouper._get_grouper
* REF: set grouping_vector once at the end of __init__
* REF: make passed_categorical a cache_readonly
* mypy fixup
1 parent ad03e49 commit 9bd15db

File tree

2 files changed

+35
-51
lines changed

2 files changed

+35
-51
lines changed

pandas/core/groupby/grouper.py

+35 -33
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,6 @@ class Grouping:
464464

465465
_codes: npt.NDArray[np.signedinteger] | None = None
466466
_group_index: Index | None = None
467-
_passed_categorical: bool
468467
_all_grouper: Categorical | None
469468
_orig_cats: Index | None
470469
_index: Index
@@ -483,7 +482,7 @@ def __init__(
483482
) -> None:
484483
self.level = level
485484
self._orig_grouper = grouper
486-
self.grouping_vector = _convert_grouper(index, grouper)
485+
grouping_vector = _convert_grouper(index, grouper)
487486
self._all_grouper = None
488487
self._orig_cats = None
489488
self._index = index
@@ -494,8 +493,6 @@ def __init__(
494493
self._dropna = dropna
495494
self._uniques = uniques
496495

497-
self._passed_categorical = False
498-
499496
# we have a single grouper which may be a myriad of things,
500497
# some of which are dependent on the passing in level
501498

@@ -509,78 +506,83 @@ def __init__(
509506
else:
510507
index_level = index
511508

512-
if self.grouping_vector is None:
513-
self.grouping_vector = index_level
509+
if grouping_vector is None:
510+
grouping_vector = index_level
514511
else:
515-
mapper = self.grouping_vector
516-
self.grouping_vector = index_level.map(mapper)
512+
mapper = grouping_vector
513+
grouping_vector = index_level.map(mapper)
517514

518515
# a passed Grouper like, directly get the grouper in the same way
519516
# as single grouper groupby, use the group_info to get codes
520-
elif isinstance(self.grouping_vector, Grouper):
517+
elif isinstance(grouping_vector, Grouper):
521518
# get the new grouper; we already have disambiguated
522519
# what key/level refer to exactly, don't need to
523520
# check again as we have by this point converted these
524521
# to an actual value (rather than a pd.Grouper)
525522
assert self.obj is not None # for mypy
526-
newgrouper, newobj = self.grouping_vector._get_grouper(
527-
self.obj, validate=False
528-
)
523+
newgrouper, newobj = grouping_vector._get_grouper(self.obj, validate=False)
529524
self.obj = newobj
530525

531-
ng = newgrouper._get_grouper()
532526
if isinstance(newgrouper, ops.BinGrouper):
533-
# in this case we have `ng is newgrouper`
534-
self.grouping_vector = ng
527+
# TODO: can we unwrap this and get a tighter typing
528+
# for self.grouping_vector?
529+
grouping_vector = newgrouper
535530
else:
536531
# ops.BaseGrouper
532+
# TODO: 2023-02-03 no test cases with len(newgrouper.groupings) > 1.
533+
# If that were to occur, would we be throwing out information?
534+
# error: Cannot determine type of "grouping_vector" [has-type]
535+
ng = newgrouper.groupings[0].grouping_vector # type: ignore[has-type]
537536
# use Index instead of ndarray so we can recover the name
538-
self.grouping_vector = Index(ng, name=newgrouper.result_index.name)
537+
grouping_vector = Index(ng, name=newgrouper.result_index.name)
539538

540539
elif not isinstance(
541-
self.grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
540+
grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
542541
):
543542
# no level passed
544-
if getattr(self.grouping_vector, "ndim", 1) != 1:
545-
t = self.name or str(type(self.grouping_vector))
543+
if getattr(grouping_vector, "ndim", 1) != 1:
544+
t = str(type(grouping_vector))
546545
raise ValueError(f"Grouper for '{t}' not 1-dimensional")
547546

548-
self.grouping_vector = index.map(self.grouping_vector)
547+
grouping_vector = index.map(grouping_vector)
549548

550549
if not (
551-
hasattr(self.grouping_vector, "__len__")
552-
and len(self.grouping_vector) == len(index)
550+
hasattr(grouping_vector, "__len__")
551+
and len(grouping_vector) == len(index)
553552
):
554-
grper = pprint_thing(self.grouping_vector)
553+
grper = pprint_thing(grouping_vector)
555554
errmsg = (
556555
"Grouper result violates len(labels) == "
557556
f"len(data)\nresult: {grper}"
558557
)
559-
self.grouping_vector = None # Try for sanity
560558
raise AssertionError(errmsg)
561559

562-
if isinstance(self.grouping_vector, np.ndarray):
563-
if self.grouping_vector.dtype.kind in ["m", "M"]:
560+
if isinstance(grouping_vector, np.ndarray):
561+
if grouping_vector.dtype.kind in ["m", "M"]:
564562
# if we have a date/time-like grouper, make sure that we have
565563
# Timestamps like
566564
# TODO 2022-10-08 we only have one test that gets here and
567565
# values are already in nanoseconds in that case.
568-
self.grouping_vector = Series(self.grouping_vector).to_numpy()
569-
elif is_categorical_dtype(self.grouping_vector):
566+
grouping_vector = Series(grouping_vector).to_numpy()
567+
elif is_categorical_dtype(grouping_vector):
570568
# a passed Categorical
571-
self._passed_categorical = True
572-
573-
self._orig_cats = self.grouping_vector.categories
574-
self.grouping_vector, self._all_grouper = recode_for_groupby(
575-
self.grouping_vector, sort, observed
569+
self._orig_cats = grouping_vector.categories
570+
grouping_vector, self._all_grouper = recode_for_groupby(
571+
grouping_vector, sort, observed
576572
)
577573

574+
self.grouping_vector = grouping_vector
575+
578576
def __repr__(self) -> str:
579577
return f"Grouping({self.name})"
580578

581579
def __iter__(self) -> Iterator:
582580
return iter(self.indices)
583581

582+
@cache_readonly
583+
def _passed_categorical(self) -> bool:
584+
return is_categorical_dtype(self.grouping_vector)
585+
584586
@cache_readonly
585587
def name(self) -> Hashable:
586588
ilevel = self._ilevel

pandas/core/groupby/ops.py

-18
Original file line numberDiff line numberDiff line change
@@ -745,15 +745,6 @@ def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
745745
ids, _, ngroups = self.group_info
746746
return get_splitter(data, ids, ngroups, axis=axis)
747747

748-
def _get_grouper(self):
749-
"""
750-
We are a grouper as part of another's groupings.
751-
752-
We have a specific method of grouping, so cannot
753-
convert to a Index for our grouper.
754-
"""
755-
return self.groupings[0].grouping_vector
756-
757748
@final
758749
@cache_readonly
759750
def group_keys_seq(self):
@@ -1110,15 +1101,6 @@ def nkeys(self) -> int:
11101101
# still matches len(self.groupings), but we can hard-code
11111102
return 1
11121103

1113-
def _get_grouper(self):
1114-
"""
1115-
We are a grouper as part of another's groupings.
1116-
1117-
We have a specific method of grouping, so cannot
1118-
convert to a Index for our grouper.
1119-
"""
1120-
return self
1121-
11221104
def get_iterator(self, data: NDFrame, axis: AxisInt = 0):
11231105
"""
11241106
Groupby iterator

0 commit comments

Comments
 (0)