Skip to content

Commit 3ce07ea

Browse files
author
Junya Hayashi
committed
ENH: Refactor groupby for Categorical grouper
1 parent b53ef23 commit 3ce07ea

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

pandas/core/groupby.py

+9 -26
Original file line number | Diff line number | Diff line change
@@ -1862,7 +1862,6 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
18621862
self.grouper = grouper.values
18631863

18641864
# pre-computed
1865-
self._grouping_type = None
18661865
self._should_compress = True
18671866

18681867
# we have a single grouper which may be a myriad of things, some of which are
@@ -1887,8 +1886,6 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
18871886
level_values = index.levels[level].take(inds)
18881887
self.grouper = level_values.map(self.grouper)
18891888
else:
1890-
self._grouping_type = "level"
1891-
18921889
# all levels may not be observed
18931890
labels, uniques = algos.factorize(inds, sort=True)
18941891

@@ -1913,17 +1910,10 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
19131910

19141911
# a passed Categorical
19151912
elif isinstance(self.grouper, Categorical):
1916-
1917-
factor = self.grouper
1918-
self._grouping_type = "categorical"
1919-
1920-
# Is there any way to avoid this?
1921-
self.grouper = np.asarray(factor)
1922-
1923-
self._labels = factor.codes
1924-
self._group_index = factor.categories
1913+
self._labels = self.grouper.codes
1914+
self._group_index = self.grouper.categories
19251915
if self.name is None:
1926-
self.name = factor.name
1916+
self.name = self.grouper.name
19271917

19281918
# a passed Grouper like
19291919
elif isinstance(self.grouper, Grouper):
@@ -1936,8 +1926,8 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
19361926
self.name = grouper.name
19371927

19381928
# no level passed
1939-
if not isinstance(self.grouper, (Series, Index, np.ndarray)):
1940-
if getattr(self.grouper,'ndim', 1) != 1:
1929+
if not isinstance(self.grouper, (Series, Index, Categorical, np.ndarray)):
1930+
if getattr(self.grouper, 'ndim', 1) != 1:
19411931
t = self.name or str(type(self.grouper))
19421932
raise ValueError("Grouper for '%s' not 1-dimensional" % t)
19431933
self.grouper = self.index.map(self.grouper)
@@ -1988,22 +1978,15 @@ def group_index(self):
19881978
return self._group_index
19891979

19901980
def _make_labels(self):
1991-
if self._grouping_type in ("level", "categorical"): # pragma: no cover
1992-
raise Exception(
1993-
'Should not call this method grouping by level or categorical')
1994-
else:
1981+
if self._labels is None or self._group_index is None:
19951982
labels, uniques = algos.factorize(self.grouper, sort=self.sort)
19961983
uniques = Index(uniques, name=self.name)
19971984
self._labels = labels
19981985
self._group_index = uniques
19991986

2000-
_groups = None
2001-
2002-
@property
1987+
@cache_readonly
20031988
def groups(self):
2004-
if self._groups is None:
2005-
self._groups = self.index.groupby(self.grouper)
2006-
return self._groups
1989+
return self.index.groupby(self.grouper)
20071990

20081991
def _get_grouper(obj, key=None, axis=0, level=None, sort=True):
20091992
"""
@@ -3239,7 +3222,7 @@ def _reindex_output(self, result):
32393222
return result
32403223
elif len(groupings) == 1:
32413224
return result
3242-
elif not any([ping._grouping_type == "categorical"
3225+
elif not any([isinstance(ping.grouper, Categorical)
32433226
for ping in groupings]):
32443227
return result
32453228

pandas/tests/test_groupby.py

+28
Original file line number | Diff line number | Diff line change
@@ -3297,6 +3297,34 @@ def test_groupby_categorical(self):
32973297
expected.index.names = ['myfactor', None]
32983298
assert_frame_equal(desc_result, expected)
32993299

3300+
def test_groupby_datetime_categorical(self):
3301+
# GH9049: ensure backward compatibility
3302+
levels = pd.date_range('2014-01-01', periods=4)
3303+
codes = np.random.randint(0, 4, size=100)
3304+
3305+
cats = Categorical.from_codes(codes, levels, name='myfactor')
3306+
3307+
data = DataFrame(np.random.randn(100, 4))
3308+
3309+
result = data.groupby(cats).mean()
3310+
3311+
expected = data.groupby(np.asarray(cats)).mean()
3312+
expected = expected.reindex(levels)
3313+
expected.index.name = 'myfactor'
3314+
3315+
assert_frame_equal(result, expected)
3316+
self.assertEqual(result.index.name, cats.name)
3317+
3318+
grouped = data.groupby(cats)
3319+
desc_result = grouped.describe()
3320+
3321+
idx = cats.codes.argsort()
3322+
ord_labels = np.asarray(cats).take(idx)
3323+
ord_data = data.take(idx)
3324+
expected = ord_data.groupby(ord_labels, sort=False).describe()
3325+
expected.index.names = ['myfactor', None]
3326+
assert_frame_equal(desc_result, expected)
3327+
33003328
def test_groupby_groups_datetimeindex(self):
33013329
# #1430
33023330
from pandas.tseries.api import DatetimeIndex

0 commit comments

Comments
 (0)