Skip to content

Commit 5481324

Browse files
committed
REF: Move Categorical logic from Grouping (#46203)
1 parent e4162cd commit 5481324

File tree

2 files changed

+72
-44
lines changed

2 files changed

+72
-44
lines changed

pandas/core/groupby/categorical.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,68 @@
11
from __future__ import annotations
2+
import dataclasses
3+
from typing import final
24

35
import numpy as np
46

7+
from pandas._typing import ArrayLike
58
from pandas.core.algorithms import unique1d
69
from pandas.core.arrays.categorical import (
710
Categorical,
811
CategoricalDtype,
912
recode_for_categories,
1013
)
11-
from pandas.core.indexes.api import CategoricalIndex
14+
from pandas.core.indexes.api import (
15+
Index,
16+
CategoricalIndex,
17+
)
18+
19+
20+
@final
21+
@dataclasses.dataclass
22+
class CategoricalGrouper:
23+
original_grouping_vector: Categorical
24+
new_grouping_vector: Categorical
25+
observed: bool
26+
sort: bool
27+
28+
def result_index(self, new_index: Index) -> Index:
29+
if self.original_grouping_vector is None:
30+
return new_index
31+
assert isinstance(new_index, CategoricalIndex)
32+
return recode_from_groupby(self.original_grouping_vector, self.sort, new_index)
33+
34+
def group_index(self):
35+
return self.original_grouping_vector
36+
37+
@classmethod
38+
def make(cls, grouping_vector, sort: bool, observed: bool) -> "CategoricalGrouper":
39+
new_grouping_vector, original_grouping_vector = recode_for_groupby(
40+
grouping_vector, sort, observed
41+
)
42+
return cls(
43+
original_grouping_vector=original_grouping_vector,
44+
new_grouping_vector=new_grouping_vector,
45+
observed=observed,
46+
sort=sort,
47+
)
48+
49+
def codes_and_uniques(self, cat: Categorical) -> tuple[np.ndarray, ArrayLike]:
50+
# we make a CategoricalIndex out of the cat grouper
51+
# preserving the categories / ordered attributes
52+
categories = cat.categories
53+
54+
if self.observed:
55+
ucodes = unique1d(cat.codes)
56+
ucodes = ucodes[ucodes != -1]
57+
if self.sort or cat.ordered:
58+
ucodes = np.sort(ucodes)
59+
else:
60+
ucodes = np.arange(len(categories))
61+
62+
uniques = Categorical.from_codes(
63+
codes=ucodes, categories=categories, ordered=cat.ordered
64+
)
65+
return cat.codes, uniques
1266

1367

1468
def recode_for_groupby(

pandas/core/groupby/grouper.py

+17-43
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@
3838
import pandas.core.common as com
3939
from pandas.core.frame import DataFrame
4040
from pandas.core.groupby import ops
41-
from pandas.core.groupby.categorical import (
42-
recode_for_groupby,
43-
recode_from_groupby,
44-
)
41+
from pandas.core.groupby.categorical import CategoricalGrouper
4542
from pandas.core.indexes.api import (
46-
CategoricalIndex,
4743
Index,
4844
MultiIndex,
4945
)
@@ -461,8 +457,7 @@ class Grouping:
461457

462458
_codes: npt.NDArray[np.signedinteger] | None = None
463459
_group_index: Index | None = None
464-
_passed_categorical: bool
465-
_all_grouper: Categorical | None
460+
_cat_info: CategoricalGrouper | None = None
466461
_index: Index
467462

468463
def __init__(
@@ -479,16 +474,12 @@ def __init__(
479474
self.level = level
480475
self._orig_grouper = grouper
481476
self.grouping_vector = _convert_grouper(index, grouper)
482-
self._all_grouper = None
483477
self._index = index
484478
self._sort = sort
485479
self.obj = obj
486-
self._observed = observed
487480
self.in_axis = in_axis
488481
self._dropna = dropna
489482

490-
self._passed_categorical = False
491-
492483
# we have a single grouper which may be a myriad of things,
493484
# some of which are dependent on the passing in level
494485

@@ -527,13 +518,10 @@ def __init__(
527518
self.grouping_vector = Index(ng, name=newgrouper.result_index.name)
528519

529520
elif is_categorical_dtype(self.grouping_vector):
530-
# a passed Categorical
531-
self._passed_categorical = True
532-
533-
self.grouping_vector, self._all_grouper = recode_for_groupby(
534-
self.grouping_vector, sort, observed
521+
self._cat_info = CategoricalGrouper.make(
522+
self.grouping_vector, sort, observed, dropna=dropna
535523
)
536-
524+
self.grouping_vector = self._cat_info.new_grouping_vector
537525
elif not isinstance(
538526
self.grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
539527
):
@@ -631,20 +619,23 @@ def group_arraylike(self) -> ArrayLike:
631619
# _group_index is set in __init__ for MultiIndex cases
632620
return self._group_index._values
633621

634-
elif self._all_grouper is not None:
622+
elif (
623+
self._cat_info is not None
624+
and self._cat_info.original_grouping_vector is not None
625+
):
635626
# retain dtype for categories, including unobserved ones
636627
return self.result_index._values
637628

638629
return self._codes_and_uniques[1]
639630

640631
@cache_readonly
641632
def result_index(self) -> Index:
642-
# result_index retains dtype for categories, including unobserved ones,
643-
# which group_index does not
644-
if self._all_grouper is not None:
645-
group_idx = self.group_index
646-
assert isinstance(group_idx, CategoricalIndex)
647-
return recode_from_groupby(self._all_grouper, self._sort, group_idx)
633+
"""
634+
result_index retains dtype for categories, including unobserved ones,
635+
which group_index does not
636+
"""
637+
if self._cat_info is not None:
638+
return self._cat_info.result_index(self.group_index)
648639
return self.group_index
649640

650641
@cache_readonly
@@ -658,25 +649,8 @@ def group_index(self) -> Index:
658649

659650
@cache_readonly
660651
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
661-
if self._passed_categorical:
662-
# we make a CategoricalIndex out of the cat grouper
663-
# preserving the categories / ordered attributes
664-
cat = self.grouping_vector
665-
categories = cat.categories
666-
667-
if self._observed:
668-
ucodes = algorithms.unique1d(cat.codes)
669-
ucodes = ucodes[ucodes != -1]
670-
if self._sort or cat.ordered:
671-
ucodes = np.sort(ucodes)
672-
else:
673-
ucodes = np.arange(len(categories))
674-
675-
uniques = Categorical.from_codes(
676-
codes=ucodes, categories=categories, ordered=cat.ordered
677-
)
678-
return cat.codes, uniques
679-
652+
if self._cat_info is not None:
653+
return self._cat_info.codes_and_uniques(self.grouping_vector)
680654
elif isinstance(self.grouping_vector, ops.BaseGrouper):
681655
# we have a list of groupers
682656
codes = self.grouping_vector.codes_info

0 commit comments

Comments
 (0)