Skip to content

Commit 5f2009a

Browse files
committed
REF: Move Categorical logic from Grouping (#46203)
1 parent e4162cd commit 5f2009a

File tree

2 files changed

+70
-43
lines changed

2 files changed

+70
-43
lines changed

pandas/core/groupby/categorical.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,67 @@
11
from __future__ import annotations
2+
import dataclasses
3+
from typing import final
24

35
import numpy as np
46

7+
from pandas._typing import ArrayLike
58
from pandas.core.algorithms import unique1d
69
from pandas.core.arrays.categorical import (
710
Categorical,
811
CategoricalDtype,
912
recode_for_categories,
1013
)
11-
from pandas.core.indexes.api import CategoricalIndex
14+
from pandas.core.indexes.api import (
15+
Index,
16+
CategoricalIndex,
17+
)
18+
19+
20+
@final
21+
@dataclasses.dataclass
22+
class CategoricalGrouper:
23+
original_grouping_vector: Categorical
24+
new_grouping_vector: Categorical
25+
observed: bool
26+
sort: bool
27+
28+
def result_index(self, group_index: Index) -> Index:
29+
if self.original_grouping_vector is None:
30+
return group_index
31+
assert isinstance(group_index, CategoricalIndex)
32+
return recode_from_groupby(
33+
self.original_grouping_vector, self.sort, group_index
34+
)
35+
36+
@classmethod
37+
def make(cls, grouping_vector, sort: bool, observed: bool) -> "CategoricalGrouper":
38+
new_grouping_vector, original_grouping_vector = recode_for_groupby(
39+
grouping_vector, sort, observed
40+
)
41+
return cls(
42+
original_grouping_vector=original_grouping_vector,
43+
new_grouping_vector=new_grouping_vector,
44+
observed=observed,
45+
sort=sort,
46+
)
47+
48+
def codes_and_uniques(self, cat: Categorical) -> tuple[np.ndarray, ArrayLike]:
49+
# we make a CategoricalIndex out of the cat grouper
50+
# preserving the categories / ordered attributes
51+
categories = cat.categories
52+
53+
if self.observed:
54+
ucodes = unique1d(cat.codes)
55+
ucodes = ucodes[ucodes != -1]
56+
if self.sort or cat.ordered:
57+
ucodes = np.sort(ucodes)
58+
else:
59+
ucodes = np.arange(len(categories))
60+
61+
uniques = Categorical.from_codes(
62+
codes=ucodes, categories=categories, ordered=cat.ordered
63+
)
64+
return cat.codes, uniques
1265

1366

1467
def recode_for_groupby(

pandas/core/groupby/grouper.py

+16-42
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@
3838
import pandas.core.common as com
3939
from pandas.core.frame import DataFrame
4040
from pandas.core.groupby import ops
41-
from pandas.core.groupby.categorical import (
42-
recode_for_groupby,
43-
recode_from_groupby,
44-
)
41+
from pandas.core.groupby.categorical import CategoricalGrouper
4542
from pandas.core.indexes.api import (
46-
CategoricalIndex,
4743
Index,
4844
MultiIndex,
4945
)
@@ -461,8 +457,7 @@ class Grouping:
461457

462458
_codes: npt.NDArray[np.signedinteger] | None = None
463459
_group_index: Index | None = None
464-
_passed_categorical: bool
465-
_all_grouper: Categorical | None
460+
_cat_grouper: CategoricalGrouper | None = None
466461
_index: Index
467462

468463
def __init__(
@@ -479,16 +474,12 @@ def __init__(
479474
self.level = level
480475
self._orig_grouper = grouper
481476
self.grouping_vector = _convert_grouper(index, grouper)
482-
self._all_grouper = None
483477
self._index = index
484478
self._sort = sort
485479
self.obj = obj
486-
self._observed = observed
487480
self.in_axis = in_axis
488481
self._dropna = dropna
489482

490-
self._passed_categorical = False
491-
492483
# we have a single grouper which may be a myriad of things,
493484
# some of which are dependent on the passing in level
494485

@@ -527,13 +518,10 @@ def __init__(
527518
self.grouping_vector = Index(ng, name=newgrouper.result_index.name)
528519

529520
elif is_categorical_dtype(self.grouping_vector):
530-
# a passed Categorical
531-
self._passed_categorical = True
532-
533-
self.grouping_vector, self._all_grouper = recode_for_groupby(
521+
self._cat_grouper = CategoricalGrouper.make(
534522
self.grouping_vector, sort, observed
535523
)
536-
524+
self.grouping_vector = self._cat_grouper.new_grouping_vector
537525
elif not isinstance(
538526
self.grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
539527
):
@@ -631,20 +619,23 @@ def group_arraylike(self) -> ArrayLike:
631619
# _group_index is set in __init__ for MultiIndex cases
632620
return self._group_index._values
633621

634-
elif self._all_grouper is not None:
622+
elif (
623+
self._cat_grouper is not None
624+
and self._cat_grouper.original_grouping_vector is not None
625+
):
635626
# retain dtype for categories, including unobserved ones
636627
return self.result_index._values
637628

638629
return self._codes_and_uniques[1]
639630

640631
@cache_readonly
641632
def result_index(self) -> Index:
642-
# result_index retains dtype for categories, including unobserved ones,
643-
# which group_index does not
644-
if self._all_grouper is not None:
645-
group_idx = self.group_index
646-
assert isinstance(group_idx, CategoricalIndex)
647-
return recode_from_groupby(self._all_grouper, self._sort, group_idx)
633+
"""
634+
result_index retains dtype for categories, including unobserved ones,
635+
which group_index does not
636+
"""
637+
if self._cat_grouper is not None:
638+
return self._cat_grouper.result_index(self.group_index)
648639
return self.group_index
649640

650641
@cache_readonly
@@ -658,25 +649,8 @@ def group_index(self) -> Index:
658649

659650
@cache_readonly
660651
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
661-
if self._passed_categorical:
662-
# we make a CategoricalIndex out of the cat grouper
663-
# preserving the categories / ordered attributes
664-
cat = self.grouping_vector
665-
categories = cat.categories
666-
667-
if self._observed:
668-
ucodes = algorithms.unique1d(cat.codes)
669-
ucodes = ucodes[ucodes != -1]
670-
if self._sort or cat.ordered:
671-
ucodes = np.sort(ucodes)
672-
else:
673-
ucodes = np.arange(len(categories))
674-
675-
uniques = Categorical.from_codes(
676-
codes=ucodes, categories=categories, ordered=cat.ordered
677-
)
678-
return cat.codes, uniques
679-
652+
if self._cat_grouper is not None:
653+
return self._cat_grouper.codes_and_uniques(self.grouping_vector)
680654
elif isinstance(self.grouping_vector, ops.BaseGrouper):
681655
# we have a list of groupers
682656
codes = self.grouping_vector.codes_info

0 commit comments

Comments
 (0)