Skip to content

Commit e7f6b0f

Browse files
committed
REF: Move Categorical logic from Grouping (#46203)
1 parent e4162cd commit e7f6b0f

File tree

2 files changed

+72
-43
lines changed

2 files changed

+72
-43
lines changed

pandas/core/groupby/categorical.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,69 @@
11
from __future__ import annotations
22

3+
import dataclasses
4+
from typing import final
5+
36
import numpy as np
47

8+
from pandas._typing import ArrayLike
9+
510
from pandas.core.algorithms import unique1d
611
from pandas.core.arrays.categorical import (
712
Categorical,
813
CategoricalDtype,
914
recode_for_categories,
1015
)
11-
from pandas.core.indexes.api import CategoricalIndex
16+
from pandas.core.indexes.api import (
17+
CategoricalIndex,
18+
Index,
19+
)
20+
21+
22+
@final
23+
@dataclasses.dataclass
24+
class CategoricalGrouper:
25+
original_grouping_vector: Categorical
26+
new_grouping_vector: Categorical
27+
observed: bool
28+
sort: bool
29+
30+
def result_index(self, group_index: Index) -> Index:
31+
if self.original_grouping_vector is None:
32+
return group_index
33+
assert isinstance(group_index, CategoricalIndex)
34+
return recode_from_groupby(
35+
self.original_grouping_vector, self.sort, group_index
36+
)
37+
38+
@classmethod
39+
def make(cls, grouping_vector, sort: bool, observed: bool) -> CategoricalGrouper:
40+
new_grouping_vector, original_grouping_vector = recode_for_groupby(
41+
grouping_vector, sort, observed
42+
)
43+
return cls(
44+
original_grouping_vector=original_grouping_vector,
45+
new_grouping_vector=new_grouping_vector,
46+
observed=observed,
47+
sort=sort,
48+
)
49+
50+
def codes_and_uniques(self, cat: Categorical) -> tuple[np.ndarray, ArrayLike]:
51+
# we make a CategoricalIndex out of the cat grouper
52+
# preserving the categories / ordered attributes
53+
categories = cat.categories
54+
55+
if self.observed:
56+
ucodes = unique1d(cat.codes)
57+
ucodes = ucodes[ucodes != -1]
58+
if self.sort or cat.ordered:
59+
ucodes = np.sort(ucodes)
60+
else:
61+
ucodes = np.arange(len(categories))
62+
63+
uniques = Categorical.from_codes(
64+
codes=ucodes, categories=categories, ordered=cat.ordered
65+
)
66+
return cat.codes, uniques
1267

1368

1469
def recode_for_groupby(

pandas/core/groupby/grouper.py

+16-42
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@
3838
import pandas.core.common as com
3939
from pandas.core.frame import DataFrame
4040
from pandas.core.groupby import ops
41-
from pandas.core.groupby.categorical import (
42-
recode_for_groupby,
43-
recode_from_groupby,
44-
)
41+
from pandas.core.groupby.categorical import CategoricalGrouper
4542
from pandas.core.indexes.api import (
46-
CategoricalIndex,
4743
Index,
4844
MultiIndex,
4945
)
@@ -461,8 +457,7 @@ class Grouping:
461457

462458
_codes: npt.NDArray[np.signedinteger] | None = None
463459
_group_index: Index | None = None
464-
_passed_categorical: bool
465-
_all_grouper: Categorical | None
460+
_cat_grouper: CategoricalGrouper | None = None
466461
_index: Index
467462

468463
def __init__(
@@ -479,16 +474,12 @@ def __init__(
479474
self.level = level
480475
self._orig_grouper = grouper
481476
self.grouping_vector = _convert_grouper(index, grouper)
482-
self._all_grouper = None
483477
self._index = index
484478
self._sort = sort
485479
self.obj = obj
486-
self._observed = observed
487480
self.in_axis = in_axis
488481
self._dropna = dropna
489482

490-
self._passed_categorical = False
491-
492483
# we have a single grouper which may be a myriad of things,
493484
# some of which are dependent on the passing in level
494485

@@ -527,13 +518,10 @@ def __init__(
527518
self.grouping_vector = Index(ng, name=newgrouper.result_index.name)
528519

529520
elif is_categorical_dtype(self.grouping_vector):
530-
# a passed Categorical
531-
self._passed_categorical = True
532-
533-
self.grouping_vector, self._all_grouper = recode_for_groupby(
521+
self._cat_grouper = CategoricalGrouper.make(
534522
self.grouping_vector, sort, observed
535523
)
536-
524+
self.grouping_vector = self._cat_grouper.new_grouping_vector
537525
elif not isinstance(
538526
self.grouping_vector, (Series, Index, ExtensionArray, np.ndarray)
539527
):
@@ -631,20 +619,23 @@ def group_arraylike(self) -> ArrayLike:
631619
# _group_index is set in __init__ for MultiIndex cases
632620
return self._group_index._values
633621

634-
elif self._all_grouper is not None:
622+
elif (
623+
self._cat_grouper is not None
624+
and self._cat_grouper.original_grouping_vector is not None
625+
):
635626
# retain dtype for categories, including unobserved ones
636627
return self.result_index._values
637628

638629
return self._codes_and_uniques[1]
639630

640631
@cache_readonly
641632
def result_index(self) -> Index:
642-
# result_index retains dtype for categories, including unobserved ones,
643-
# which group_index does not
644-
if self._all_grouper is not None:
645-
group_idx = self.group_index
646-
assert isinstance(group_idx, CategoricalIndex)
647-
return recode_from_groupby(self._all_grouper, self._sort, group_idx)
633+
"""
634+
result_index retains dtype for categories, including unobserved ones,
635+
which group_index does not
636+
"""
637+
if self._cat_grouper is not None:
638+
return self._cat_grouper.result_index(self.group_index)
648639
return self.group_index
649640

650641
@cache_readonly
@@ -658,25 +649,8 @@ def group_index(self) -> Index:
658649

659650
@cache_readonly
660651
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
661-
if self._passed_categorical:
662-
# we make a CategoricalIndex out of the cat grouper
663-
# preserving the categories / ordered attributes
664-
cat = self.grouping_vector
665-
categories = cat.categories
666-
667-
if self._observed:
668-
ucodes = algorithms.unique1d(cat.codes)
669-
ucodes = ucodes[ucodes != -1]
670-
if self._sort or cat.ordered:
671-
ucodes = np.sort(ucodes)
672-
else:
673-
ucodes = np.arange(len(categories))
674-
675-
uniques = Categorical.from_codes(
676-
codes=ucodes, categories=categories, ordered=cat.ordered
677-
)
678-
return cat.codes, uniques
679-
652+
if self._cat_grouper is not None:
653+
return self._cat_grouper.codes_and_uniques(self.grouping_vector)
680654
elif isinstance(self.grouping_vector, ops.BaseGrouper):
681655
# we have a list of groupers
682656
codes = self.grouping_vector.codes_info

0 commit comments

Comments
 (0)