Skip to content

Commit c61318d

Browse files
committed
BUG: groupby with categorical and other columns
closes #14942
1 parent 60fe82c commit c61318d

File tree

6 files changed

+549
-348
lines changed

6 files changed

+549
-348
lines changed

doc/source/whatsnew/v0.23.0.txt

+35
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,41 @@ If you wish to retain the old behavior while using Python >= 3.6, you can use
527527
'Taxes': -200,
528528
'Net result': 300}).sort_index()
529529

530+
.. _whatsnew_0230.api_breaking.categorical_grouping:
531+
532+
Categorical Groupers will now require passing the observed keyword
533+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
534+
535+
In previous versions, grouping by 1 or more categorical columns would result in an index that was the cartesian product of all of the categories for
536+
each grouper, not just the observed values. ``.groupby()`` has gained the ``observed`` keyword to toggle this behavior. The default remains backward
537+
compatible (generate a cartesian product). Pandas will show a ``FutureWarning`` if the ``observed`` keyword is not passed; the default will
538+
change to ``observed=True`` in the future. (:issue:`14942`, :issue:`8138`, :issue:`15217`, :issue:`17594`, :issue:`8669`, :issue:`20583`)
539+
540+
541+
.. ipython:: python
542+
543+
cat1 = pd.Categorical(["a", "a", "b", "b"],
544+
categories=["a", "b", "z"], ordered=True)
545+
cat2 = pd.Categorical(["c", "d", "c", "d"],
546+
categories=["c", "d", "y"], ordered=True)
547+
df = pd.DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]})
548+
df['C'] = ['foo', 'bar'] * 2
549+
df
550+
551+
Previous Behavior (show all values):
552+
553+
.. code-block:: python
554+
555+
556+
df.groupby(['A', 'B', 'C'], observed=False).count()
557+
558+
559+
New Behavior (show only observed values):
560+
561+
.. ipython:: python
562+
563+
df.groupby(['A', 'B', 'C'], observed=True).count()
564+
530565
.. _whatsnew_0230.api_breaking.deprecate_panel:
531566

532567
Deprecate Panel

pandas/core/generic.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6599,7 +6599,7 @@ def clip_lower(self, threshold, axis=None, inplace=False):
65996599
axis=axis, inplace=inplace)
66006600

66016601
def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
6602-
group_keys=True, squeeze=False, **kwargs):
6602+
group_keys=True, squeeze=False, observed=None, **kwargs):
66036603
"""
66046604
Group series using mapper (dict or key function, apply given function
66056605
to group, return result as series) or by a series of columns.
@@ -6632,6 +6632,13 @@ def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
66326632
squeeze : boolean, default False
66336633
reduce the dimensionality of the return type if possible,
66346634
otherwise return a consistent type
6635+
observed : boolean, default None
6636+
if True: only show observed values for categorical groupers
6637+
if False: show all values for categorical groupers
6638+
if None: if any categorical groupers, show a FutureWarning,
6639+
default to False
6640+
6641+
.. versionadded:: 0.23.0
66356642
66366643
Returns
66376644
-------
@@ -6665,7 +6672,7 @@ def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
66656672
axis = self._get_axis_number(axis)
66666673
return groupby(self, by=by, axis=axis, level=level, as_index=as_index,
66676674
sort=sort, group_keys=group_keys, squeeze=squeeze,
6668-
**kwargs)
6675+
observed=observed, **kwargs)
66696676

66706677
def asfreq(self, freq, method=None, how=None, normalize=False,
66716678
fill_value=None):

pandas/core/groupby/groupby.py

+56-17
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,8 @@ class _GroupBy(PandasObject, SelectionMixin):
556556

557557
def __init__(self, obj, keys=None, axis=0, level=None,
558558
grouper=None, exclusions=None, selection=None, as_index=True,
559-
sort=True, group_keys=True, squeeze=False, **kwargs):
559+
sort=True, group_keys=True, squeeze=False,
560+
observed=None, **kwargs):
560561

561562
self._selection = selection
562563

@@ -576,13 +577,15 @@ def __init__(self, obj, keys=None, axis=0, level=None,
576577
self.sort = sort
577578
self.group_keys = group_keys
578579
self.squeeze = squeeze
580+
self.observed = observed
579581
self.mutated = kwargs.pop('mutated', False)
580582

581583
if grouper is None:
582584
grouper, exclusions, obj = _get_grouper(obj, keys,
583585
axis=axis,
584586
level=level,
585587
sort=sort,
588+
observed=observed,
586589
mutated=self.mutated)
587590

588591
self.obj = obj
@@ -2331,18 +2334,21 @@ def ngroups(self):
23312334
def recons_labels(self):
23322335
comp_ids, obs_ids, _ = self.group_info
23332336
labels = (ping.labels for ping in self.groupings)
2334-
return decons_obs_group_ids(comp_ids,
2335-
obs_ids, self.shape, labels, xnull=True)
2337+
return decons_obs_group_ids(
2338+
comp_ids, obs_ids, self.shape, labels, xnull=True)
23362339

23372340
@cache_readonly
23382341
def result_index(self):
23392342
if not self.compressed and len(self.groupings) == 1:
23402343
return self.groupings[0].group_index.rename(self.names[0])
23412344

2342-
return MultiIndex(levels=[ping.group_index for ping in self.groupings],
2343-
labels=self.recons_labels,
2344-
verify_integrity=False,
2345-
names=self.names)
2345+
labels = self.recons_labels
2346+
levels = [ping.group_index for ping in self.groupings]
2347+
result = MultiIndex(levels=levels,
2348+
labels=labels,
2349+
verify_integrity=False,
2350+
names=self.names)
2351+
return result
23462352

23472353
def get_group_levels(self):
23482354
if not self.compressed and len(self.groupings) == 1:
@@ -2883,6 +2889,7 @@ class Grouping(object):
28832889
obj :
28842890
name :
28852891
level :
2892+
observed : If we are a Categorical, use the observed values
28862893
in_axis : if the Grouping is a column in self.obj and hence among
28872894
Groupby.exclusions list
28882895
@@ -2898,14 +2905,15 @@ class Grouping(object):
28982905
"""
28992906

29002907
def __init__(self, index, grouper=None, obj=None, name=None, level=None,
2901-
sort=True, in_axis=False):
2908+
sort=True, observed=None, in_axis=False):
29022909

29032910
self.name = name
29042911
self.level = level
29052912
self.grouper = _convert_grouper(index, grouper)
29062913
self.index = index
29072914
self.sort = sort
29082915
self.obj = obj
2916+
self.observed = observed
29092917
self.in_axis = in_axis
29102918

29112919
# right place for this?
@@ -2954,16 +2962,34 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
29542962
elif is_categorical_dtype(self.grouper):
29552963

29562964
self.grouper = self.grouper._codes_for_groupby(self.sort)
2965+
codes = self.grouper.codes
2966+
categories = self.grouper.categories
29572967

29582968
# we make a CategoricalIndex out of the cat grouper
29592969
# preserving the categories / ordered attributes
2960-
self._labels = self.grouper.codes
2970+
self._labels = codes
2971+
2972+
# Use the observed values of the grouper if indicated
2973+
observed = self.observed
2974+
if observed is None:
2975+
msg = ("pass observed=True to ensure that a "
2976+
"categorical grouper only returns the "
2977+
"observed groupers, or\n"
2978+
"observed=False to return NA for non-observed "
2979+
"values\n")
2980+
warnings.warn(msg, FutureWarning, stacklevel=5)
2981+
observed = False
2982+
2983+
if observed:
2984+
codes = algorithms.unique1d(codes)
2985+
else:
2986+
codes = np.arange(len(categories))
29612987

2962-
c = self.grouper.categories
29632988
self._group_index = CategoricalIndex(
2964-
Categorical.from_codes(np.arange(len(c)),
2965-
categories=c,
2966-
ordered=self.grouper.ordered))
2989+
Categorical.from_codes(
2990+
codes=codes,
2991+
categories=categories,
2992+
ordered=self.grouper.ordered))
29672993

29682994
# we are done
29692995
if isinstance(self.grouper, Grouping):
@@ -3048,7 +3074,7 @@ def groups(self):
30483074

30493075

30503076
def _get_grouper(obj, key=None, axis=0, level=None, sort=True,
3051-
mutated=False, validate=True):
3077+
observed=None, mutated=False, validate=True):
30523078
"""
30533079
create and return a BaseGrouper, which is an internal
30543080
mapping of how to create the grouper indexers.
@@ -3065,6 +3091,9 @@ def _get_grouper(obj, key=None, axis=0, level=None, sort=True,
30653091
are and then creates a Grouping for each one, combined into
30663092
a BaseGrouper.
30673093
3094+
If observed & we have a categorical grouper, only show the observed
3095+
values
3096+
30683097
If validate, then check for key/level overlaps
30693098
30703099
"""
@@ -3243,6 +3272,7 @@ def is_in_obj(gpr):
32433272
name=name,
32443273
level=level,
32453274
sort=sort,
3275+
observed=observed,
32463276
in_axis=in_axis) \
32473277
if not isinstance(gpr, Grouping) else gpr
32483278

@@ -4154,7 +4184,7 @@ def first_not_none(values):
41544184
not_indexed_same=not_indexed_same)
41554185
elif self.grouper.groupings is not None:
41564186
if len(self.grouper.groupings) > 1:
4157-
key_index = MultiIndex.from_tuples(keys, names=key_names)
4187+
key_index = self.grouper.result_index
41584188

41594189
else:
41604190
ping = self.grouper.groupings[0]
@@ -4244,8 +4274,9 @@ def first_not_none(values):
42444274

42454275
# normally use vstack as its faster than concat
42464276
# and if we have mi-columns
4247-
if isinstance(v.index,
4248-
MultiIndex) or key_index is None:
4277+
if (isinstance(v.index, MultiIndex) or
4278+
key_index is None or
4279+
isinstance(key_index, MultiIndex)):
42494280
stacked_values = np.vstack(map(np.asarray, values))
42504281
result = DataFrame(stacked_values, index=key_index,
42514282
columns=index)
@@ -4696,6 +4727,14 @@ def _reindex_output(self, result):
46964727
46974728
This can re-expand the output space
46984729
"""
4730+
4731+
# TODO(jreback): remove completely
4732+
# when observed parameter is defaulted to True
4733+
# gh-20583
4734+
4735+
if self.observed:
4736+
return result
4737+
46994738
groupings = self.grouper.groupings
47004739
if groupings is None:
47014740
return result

pandas/core/reshape/pivot.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
7979
pass
8080
values = list(values)
8181

82-
grouped = data.groupby(keys)
82+
grouped = data.groupby(keys, observed=dropna)
8383
agged = grouped.agg(aggfunc)
8484

8585
table = agged
@@ -241,10 +241,13 @@ def _all_key(key):
241241
return (key, margins_name) + ('',) * (len(cols) - 1)
242242

243243
if len(rows) > 0:
244-
margin = data[rows + values].groupby(rows).agg(aggfunc)
244+
margin = data[rows + values].groupby(
245+
rows, observed=True).agg(aggfunc)
245246
cat_axis = 1
246247

247-
for key, piece in table.groupby(level=0, axis=cat_axis):
248+
for key, piece in table.groupby(level=0,
249+
axis=cat_axis,
250+
observed=True):
248251
all_key = _all_key(key)
249252

250253
# we are going to mutate this, so need to copy!
@@ -264,7 +267,9 @@ def _all_key(key):
264267
else:
265268
margin = grand_margin
266269
cat_axis = 0
267-
for key, piece in table.groupby(level=0, axis=cat_axis):
270+
for key, piece in table.groupby(level=0,
271+
axis=cat_axis,
272+
observed=True):
268273
all_key = _all_key(key)
269274
table_pieces.append(piece)
270275
table_pieces.append(Series(margin[key], index=[all_key]))
@@ -279,7 +284,8 @@ def _all_key(key):
279284
margin_keys = table.columns
280285

281286
if len(cols) > 0:
282-
row_margin = data[cols + values].groupby(cols).agg(aggfunc)
287+
row_margin = data[cols + values].groupby(
288+
cols, observed=True).agg(aggfunc)
283289
row_margin = row_margin.stack()
284290

285291
# slight hack
@@ -304,14 +310,17 @@ def _all_key():
304310
return (margins_name, ) + ('', ) * (len(cols) - 1)
305311

306312
if len(rows) > 0:
307-
margin = data[rows].groupby(rows).apply(aggfunc)
313+
margin = data[rows].groupby(rows,
314+
observed=True).apply(aggfunc)
308315
all_key = _all_key()
309316
table[all_key] = margin
310317
result = table
311318
margin_keys.append(all_key)
312319

313320
else:
314-
margin = data.groupby(level=0, axis=0).apply(aggfunc)
321+
margin = data.groupby(level=0,
322+
axis=0,
323+
observed=True).apply(aggfunc)
315324
all_key = _all_key()
316325
table[all_key] = margin
317326
result = table
@@ -322,7 +331,7 @@ def _all_key():
322331
margin_keys = table.columns
323332

324333
if len(cols):
325-
row_margin = data[cols].groupby(cols).apply(aggfunc)
334+
row_margin = data[cols].groupby(cols, observed=True).apply(aggfunc)
326335
else:
327336
row_margin = Series(np.nan, index=result.columns)
328337

0 commit comments

Comments
 (0)