Skip to content

Commit 23aafc7

Browse files
committed
REF: Move Categorical logic from Grouping (#46203)
1 parent e4162cd commit 23aafc7

File tree

2 files changed

+80
-45
lines changed

2 files changed

+80
-45
lines changed

pandas/core/groupby/categorical.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,72 @@
11
from __future__ import annotations
22

3+
import dataclasses
4+
from typing import final
5+
36
import numpy as np
47

8+
from pandas._typing import npt
9+
510
from pandas.core.algorithms import unique1d
611
from pandas.core.arrays.categorical import (
712
Categorical,
813
CategoricalDtype,
914
recode_for_categories,
1015
)
11-
from pandas.core.indexes.api import CategoricalIndex
16+
from pandas.core.indexes.api import (
17+
CategoricalIndex,
18+
Index,
19+
)
20+
21+
22+
@final
23+
@dataclasses.dataclass
24+
class CategoricalGrouper:
25+
original_grouping_vector: Categorical | None
26+
new_grouping_vector: Categorical
27+
observed: bool
28+
sort: bool
29+
30+
def result_index(self, group_index: Index) -> Index:
31+
if self.original_grouping_vector is None:
32+
return group_index
33+
assert isinstance(group_index, CategoricalIndex)
34+
return recode_from_groupby(
35+
self.original_grouping_vector, self.sort, group_index
36+
)
37+
38+
@classmethod
39+
def make(cls, grouping_vector, sort: bool, observed: bool) -> CategoricalGrouper:
40+
new_grouping_vector, original_grouping_vector = recode_for_groupby(
41+
grouping_vector, sort, observed
42+
)
43+
return cls(
44+
original_grouping_vector=original_grouping_vector,
45+
new_grouping_vector=new_grouping_vector,
46+
observed=observed,
47+
sort=sort,
48+
)
49+
50+
def codes_and_uniques(
51+
self, cat: Categorical
52+
) -> tuple[npt.NDArray[np.signedinteger], Categorical]:
53+
"""
54+
This does a version of `algorithms.factorize()` that works on categoricals.
55+
"""
56+
categories = cat.categories
57+
58+
if self.observed:
59+
ucodes = unique1d(cat.codes)
60+
ucodes = ucodes[ucodes != -1]
61+
if self.sort or cat.ordered:
62+
ucodes = np.sort(ucodes)
63+
else:
64+
ucodes = np.arange(len(categories))
65+
66+
uniques = Categorical.from_codes(
67+
codes=ucodes, categories=categories, ordered=cat.ordered
68+
)
69+
return cat.codes, uniques
1270

1371

1472
def recode_for_groupby(

pandas/core/groupby/grouper.py

+21-44
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
TYPE_CHECKING,
99
Any,
1010
Hashable,
11+
cast,
1112
final,
1213
)
1314
import warnings
1415

1516
import numpy as np
1617

1718
from pandas._typing import (
19+
AnyArrayLike,
1820
ArrayLike,
1921
NDFrameT,
2022
npt,
@@ -38,12 +40,8 @@
3840
import pandas.core.common as com
3941
from pandas.core.frame import DataFrame
4042
from pandas.core.groupby import ops
41-
from pandas.core.groupby.categorical import (
42-
recode_for_groupby,
43-
recode_from_groupby,
44-
)
43+
from pandas.core.groupby.categorical import CategoricalGrouper
4544
from pandas.core.indexes.api import (
46-
CategoricalIndex,
4745
Index,
4846
MultiIndex,
4947
)
@@ -461,8 +459,7 @@ class Grouping:
461459

462460
_codes: npt.NDArray[np.signedinteger] | None = None
463461
_group_index: Index | None = None
464-
_passed_categorical: bool
465-
_all_grouper: Categorical | None
462+
_cat_grouper: CategoricalGrouper | None = None
466463
_index: Index
467464

468465
def __init__(
@@ -479,16 +476,12 @@ def __init__(
479476
self.level = level
480477
self._orig_grouper = grouper
481478
self.grouping_vector = _convert_grouper(index, grouper)
482-
self._all_grouper = None
483479
self._index = index
484480
self._sort = sort
485481
self.obj = obj
486-
self._observed = observed
487482
self.in_axis = in_axis
488483
self._dropna = dropna
489484

490-
self._passed_categorical = False
491-
492485
# we have a single grouper which may be a myriad of things,
493486
# some of which are dependent on the passing in level
494487

@@ -527,13 +520,10 @@ def __init__(
527520
self.grouping_vector = Index(ng, name=newgrouper.result_index.name)
528521

529522
elif is_categorical_dtype(self.grouping_vector):
530-
# a passed Categorical
531-
self._passed_categorical = True
532-
533-
self.grouping_vector, self._all_grouper = recode_for_groupby(
523+
self._cat_grouper = CategoricalGrouper.make(
534524
self.grouping_vector, sort, observed
535525
)
536-
526+
self.grouping_vector = self._cat_grouper.new_grouping_vector
537527
elif not isinstance(
538528
self.grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
539529
):
@@ -631,20 +621,23 @@ def group_arraylike(self) -> ArrayLike:
631621
# _group_index is set in __init__ for MultiIndex cases
632622
return self._group_index._values
633623

634-
elif self._all_grouper is not None:
624+
elif (
625+
self._cat_grouper is not None
626+
and self._cat_grouper.original_grouping_vector is not None
627+
):
635628
# retain dtype for categories, including unobserved ones
636629
return self.result_index._values
637630

638-
return self._codes_and_uniques[1]
631+
return cast(ArrayLike, self._codes_and_uniques[1])
639632

640633
@cache_readonly
641634
def result_index(self) -> Index:
642-
# result_index retains dtype for categories, including unobserved ones,
643-
# which group_index does not
644-
if self._all_grouper is not None:
645-
group_idx = self.group_index
646-
assert isinstance(group_idx, CategoricalIndex)
647-
return recode_from_groupby(self._all_grouper, self._sort, group_idx)
635+
"""
636+
result_index retains dtype for categories, including unobserved ones,
637+
which group_index does not
638+
"""
639+
if self._cat_grouper is not None:
640+
return self._cat_grouper.result_index(self.group_index)
648641
return self.group_index
649642

650643
@cache_readonly
@@ -657,26 +650,10 @@ def group_index(self) -> Index:
657650
return Index._with_infer(uniques, name=self.name)
658651

659652
@cache_readonly
660-
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
661-
if self._passed_categorical:
662-
# we make a CategoricalIndex out of the cat grouper
663-
# preserving the categories / ordered attributes
664-
cat = self.grouping_vector
665-
categories = cat.categories
666-
667-
if self._observed:
668-
ucodes = algorithms.unique1d(cat.codes)
669-
ucodes = ucodes[ucodes != -1]
670-
if self._sort or cat.ordered:
671-
ucodes = np.sort(ucodes)
672-
else:
673-
ucodes = np.arange(len(categories))
674-
675-
uniques = Categorical.from_codes(
676-
codes=ucodes, categories=categories, ordered=cat.ordered
677-
)
678-
return cat.codes, uniques
679-
653+
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], AnyArrayLike]:
654+
uniques: AnyArrayLike
655+
if self._cat_grouper is not None:
656+
return self._cat_grouper.codes_and_uniques(self.grouping_vector)
680657
elif isinstance(self.grouping_vector, ops.BaseGrouper):
681658
# we have a list of groupers
682659
codes = self.grouping_vector.codes_info

0 commit comments

Comments
 (0)