Skip to content

Commit 65466f0

Browse files
benjaminrjreback
authored and committed
pivot_table very slow on Categorical data; how about an observed keyword argument? #24923 (#24953)
1 parent e464a88 commit 65466f0

File tree

5 files changed

+51
-9
lines changed

5 files changed

+51
-9
lines changed

asv_bench/benchmarks/reshape.py

+12
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def setup(self):
127127
'value1': np.random.randn(N),
128128
'value2': np.random.randn(N),
129129
'value3': np.random.randn(N)})
130+
self.df2 = DataFrame({'col1': list('abcde'), 'col2': list('fghij'),
131+
'col3': [1, 2, 3, 4, 5]})
132+
self.df2.col1 = self.df2.col1.astype('category')
133+
self.df2.col2 = self.df2.col2.astype('category')
130134

131135
def time_pivot_table(self):
132136
self.df.pivot_table(index='key1', columns=['key2', 'key3'])
@@ -139,6 +143,14 @@ def time_pivot_table_margins(self):
139143
self.df.pivot_table(index='key1', columns=['key2', 'key3'],
140144
margins=True)
141145

146+
def time_pivot_table_categorical(self):
147+
self.df2.pivot_table(index='col1', values='col3', columns='col2',
148+
aggfunc=np.sum, fill_value=0)
149+
150+
def time_pivot_table_categorical_observed(self):
151+
self.df2.pivot_table(index='col1', values='col3', columns='col2',
152+
aggfunc=np.sum, fill_value=0, observed=True)
153+
142154

143155
class Crosstab:
144156

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Other Enhancements
2828
- Indexing of ``DataFrame`` and ``Series`` now accepts zerodim ``np.ndarray`` (:issue:`24919`)
2929
- :meth:`Timestamp.replace` now supports the ``fold`` argument to disambiguate DST transition times (:issue:`25017`)
3030
- :meth:`DataFrame.at_time` and :meth:`Series.at_time` now support :meth:`datetime.time` objects with timezones (:issue:`24043`)
31+
- :meth:`DataFrame.pivot_table` now accepts an ``observed`` parameter which is passed to underlying calls to :meth:`DataFrame.groupby` to speed up grouping categorical data. (:issue:`24923`)
3132
- ``Series.str`` has gained the :meth:`Series.str.casefold` method to remove all case distinctions present in a string (:issue:`25405`)
3233
- :meth:`DataFrame.set_index` now works for instances of ``abc.Iterator``, provided their output is of the same length as the calling frame (:issue:`22484`, :issue:`24984`)
3334
- :meth:`DatetimeIndex.union` now supports the ``sort`` argument. The behaviour of the sort parameter matches that of :meth:`Index.union` (:issue:`24994`)

pandas/core/frame.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -5695,6 +5695,12 @@ def pivot(self, index=None, columns=None, values=None):
56955695
margins_name : string, default 'All'
56965696
Name of the row / column that will contain the totals
56975697
when margins is True.
5698+
observed : boolean, default False
5699+
This only applies if any of the groupers are Categoricals.
5700+
If True: only show observed values for categorical groupers.
5701+
If False: show all values for categorical groupers.
5702+
5703+
.. versionchanged:: 0.25.0
56985704
56995705
Returns
57005706
-------
@@ -5785,12 +5791,12 @@ def pivot(self, index=None, columns=None, values=None):
57855791
@Appender(_shared_docs['pivot_table'])
57865792
def pivot_table(self, values=None, index=None, columns=None,
57875793
aggfunc='mean', fill_value=None, margins=False,
5788-
dropna=True, margins_name='All'):
5794+
dropna=True, margins_name='All', observed=False):
57895795
from pandas.core.reshape.pivot import pivot_table
57905796
return pivot_table(self, values=values, index=index, columns=columns,
57915797
aggfunc=aggfunc, fill_value=fill_value,
57925798
margins=margins, dropna=dropna,
5793-
margins_name=margins_name)
5799+
margins_name=margins_name, observed=observed)
57945800

57955801
def stack(self, level=-1, dropna=True):
57965802
"""

pandas/core/reshape/pivot.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
@Appender(_shared_docs['pivot_table'], indents=1)
2323
def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
2424
fill_value=None, margins=False, dropna=True,
25-
margins_name='All'):
25+
margins_name='All', observed=False):
2626
index = _convert_by(index)
2727
columns = _convert_by(columns)
2828

@@ -34,7 +34,8 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
3434
columns=columns,
3535
fill_value=fill_value, aggfunc=func,
3636
margins=margins, dropna=dropna,
37-
margins_name=margins_name)
37+
margins_name=margins_name,
38+
observed=observed)
3839
pieces.append(table)
3940
keys.append(getattr(func, '__name__', func))
4041

@@ -77,7 +78,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
7778
pass
7879
values = list(values)
7980

80-
grouped = data.groupby(keys, observed=False)
81+
grouped = data.groupby(keys, observed=observed)
8182
agged = grouped.agg(aggfunc)
8283
if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
8384
agged = agged.dropna(how='all')

pandas/tests/reshape/test_pivot.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,18 @@ def setup_method(self, method):
3737
'E': np.random.randn(11),
3838
'F': np.random.randn(11)})
3939

40-
def test_pivot_table(self):
40+
def test_pivot_table(self, observed):
4141
index = ['A', 'B']
4242
columns = 'C'
4343
table = pivot_table(self.data, values='D',
44-
index=index, columns=columns)
44+
index=index, columns=columns, observed=observed)
4545

4646
table2 = self.data.pivot_table(
47-
values='D', index=index, columns=columns)
47+
values='D', index=index, columns=columns, observed=observed)
4848
tm.assert_frame_equal(table, table2)
4949

5050
# this works
51-
pivot_table(self.data, values='D', index=index)
51+
pivot_table(self.data, values='D', index=index, observed=observed)
5252

5353
if len(index) > 1:
5454
assert table.index.names == tuple(index)
@@ -64,6 +64,28 @@ def test_pivot_table(self):
6464
index + [columns])['D'].agg(np.mean).unstack()
6565
tm.assert_frame_equal(table, expected)
6666

67+
def test_pivot_table_categorical_observed_equal(self, observed):
68+
# issue #24923
69+
df = pd.DataFrame({'col1': list('abcde'),
70+
'col2': list('fghij'),
71+
'col3': [1, 2, 3, 4, 5]})
72+
73+
expected = df.pivot_table(index='col1', values='col3',
74+
columns='col2', aggfunc=np.sum,
75+
fill_value=0)
76+
77+
expected.index = expected.index.astype('category')
78+
expected.columns = expected.columns.astype('category')
79+
80+
df.col1 = df.col1.astype('category')
81+
df.col2 = df.col2.astype('category')
82+
83+
result = df.pivot_table(index='col1', values='col3',
84+
columns='col2', aggfunc=np.sum,
85+
fill_value=0, observed=observed)
86+
87+
tm.assert_frame_equal(result, expected)
88+
6789
def test_pivot_table_nocols(self):
6890
df = DataFrame({'rows': ['a', 'b', 'c'],
6991
'cols': ['x', 'y', 'z'],

0 commit comments

Comments
 (0)