Skip to content

Commit 773d861

Browse files
committed
ENH: unstack multiple columns in one shot to eliminate empty columns in pivot table operations, close #1181
1 parent 094e5e4 commit 773d861

File tree

6 files changed

+144
-6
lines changed

6 files changed

+144
-6
lines changed

pandas/core/reshape.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from pandas.core.frame import DataFrame
1010

1111
from pandas.core.common import notnull, _ensure_platform_int
12-
from pandas.core.groupby import get_group_index
12+
from pandas.core.groupby import (get_group_index, _compress_group_index,
13+
decons_group_index)
14+
15+
1316
from pandas.core.index import MultiIndex
1417

1518

@@ -198,6 +201,71 @@ def get_new_index(self):
198201
return new_index
199202

200203

204+
def _unstack_multiple(data, clocs):
205+
if len(clocs) == 0:
206+
return data
207+
208+
# NOTE: This doesn't deal with hierarchical columns yet
209+
210+
index = data.index
211+
212+
clevels, rlevels = _partition(index.levels, clocs)
213+
clabels, rlabels = _partition(index.labels, clocs)
214+
cnames, rnames = _partition(index.names, clocs)
215+
216+
shape = [len(x) for x in clevels]
217+
group_index = get_group_index(clabels, shape)
218+
219+
comp_ids, obs_ids = _compress_group_index(group_index, sort=False)
220+
221+
dummy_index = MultiIndex(levels=rlevels + [obs_ids],
222+
labels=rlabels + [comp_ids],
223+
names=rnames + ['__placeholder__'])
224+
225+
dummy = DataFrame(data.values, index=dummy_index,
226+
columns=data.columns)
227+
228+
unstacked = dummy.unstack('__placeholder__')
229+
230+
if isinstance(unstacked, Series):
231+
unstcols = unstacked.index
232+
else:
233+
unstcols = unstacked.columns
234+
235+
new_levels = [unstcols.levels[0]] + clevels
236+
new_names = [data.columns.name] + cnames
237+
238+
recons_labels = decons_group_index(obs_ids, shape)
239+
240+
new_labels = [unstcols.labels[0]]
241+
for rec in recons_labels:
242+
new_labels.append(rec.take(unstcols.labels[-1]))
243+
244+
new_columns = MultiIndex(levels=new_levels, labels=new_labels,
245+
names=new_names)
246+
247+
if isinstance(unstacked, Series):
248+
unstacked.index = new_columns
249+
else:
250+
unstacked.columns = new_columns
251+
252+
return unstacked
253+
254+
255+
def _partition(values, inds):
256+
left = []
257+
right = []
258+
259+
set_inds = set(inds)
260+
261+
for i, val in enumerate(values):
262+
if i in set_inds:
263+
left.append(val)
264+
else:
265+
right.append(val)
266+
267+
return left, right
268+
201269

202270
def pivot(self, index=None, columns=None, values=None):
203271
"""

pandas/tests/test_multilevel.py

+10
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,16 @@ def test_unstack(self):
607607
# test that ints work
608608
unstacked = self.ymd.astype(int).unstack()
609609

610+
# def test_unstack_multiple_no_empty_columns(self):
611+
# index = MultiIndex.from_tuples([(0, 'foo', 0), (0, 'bar', 0),
612+
# (1, 'baz', 1), (1, 'qux', 1)])
613+
614+
# s = Series(np.random.randn(4), index=index)
615+
616+
# unstacked = s.unstack([1, 2])
617+
# expected = unstacked.dropna(axis=1, how='all')
618+
# assert_frame_equal(unstacked, expected)
619+
610620
def test_stack(self):
611621
# regular roundtrip
612622
unstacked = self.ymd.unstack()

pandas/tools/pivot.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# pylint: disable=E1103
22

33
from pandas import Series, DataFrame
4+
from pandas.core.reshape import _unstack_multiple
45
from pandas.tools.merge import concat
56
import pandas.core.common as com
67
import numpy as np
7-
import types
8+
89

910
def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
1011
fill_value=None, margins=False):
@@ -97,10 +98,12 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
9798
grouped = data.groupby(keys)
9899
agged = grouped.agg(aggfunc)
99100

100-
table = agged
101-
for i in range(len(cols)):
102-
name = table.index.names[len(rows)]
103-
table = table.unstack(name)
101+
table = _unstack_multiple(agged, range(len(rows), len(keys)))
102+
103+
# table = agged
104+
# for i in range(len(cols)):
105+
# name = table.index.names[len(rows)]
106+
# table = table.unstack(name)
104107

105108
if fill_value is not None:
106109
table = table.fillna(value=fill_value)
@@ -115,6 +118,7 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
115118

116119
return table
117120

121+
118122
DataFrame.pivot_table = pivot_table
119123

120124
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):

pandas/tools/tests/test_pivot.py

+15
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,21 @@ def test_pivot_integer_columns(self):
157157

158158
tm.assert_frame_equal(table, table2)
159159

160+
def test_pivot_no_level_overlap(self):
161+
# GH #1181
162+
163+
data = DataFrame({'a': ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'] * 2,
164+
'b': [0, 0, 0, 0, 1, 1, 1, 1] * 2,
165+
'c': (['foo'] * 4 + ['bar'] * 4) * 2,
166+
'value': np.random.randn(16)})
167+
168+
table = data.pivot_table('value', rows='a', cols=['b', 'c'])
169+
170+
grouped = data.groupby(['a', 'b', 'c'])['value'].mean()
171+
expected = grouped.unstack('b').unstack('c').dropna(axis=1, how='all')
172+
tm.assert_frame_equal(table, expected)
173+
174+
160175
class TestCrosstab(unittest.TestCase):
161176

162177
def setUp(self):

vb_suite/groupby.py

+23
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,26 @@ def f():
121121

122122
series_value_counts_int64 = Benchmark('s.value_counts()', setup,
123123
start_date=datetime(2011, 10, 21))
124+
125+
#----------------------------------------------------------------------
126+
# pivot_table
127+
128+
setup = common_setup + """
129+
fac1 = np.array(['A', 'B', 'C'], dtype='O')
130+
fac2 = np.array(['one', 'two'], dtype='O')
131+
132+
ind1 = np.random.randint(0, 3, size=100000)
133+
ind2 = np.random.randint(0, 2, size=100000)
134+
135+
df = DataFrame({'key1': fac1.take(ind1),
136+
'key2': fac2.take(ind2),
137+
'key3': fac2.take(ind2),
138+
'value1' : np.random.randn(100000),
139+
'value2' : np.random.randn(100000),
140+
'value3' : np.random.randn(100000)})
141+
"""
142+
143+
stmt = "df.pivot_table(rows='key1', cols=['key2', 'key3'])"
144+
groupby_pivot_table = Benchmark(stmt, setup, start_date=datetime(2011, 12, 15))
145+
146+

vb_suite/reshape.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from vbench.api import Benchmark
2+
from datetime import datetime
3+
4+
common_setup = """from pandas_vb_common import *
5+
index = MultiIndex.from_arrays([np.arange(100).repeat(100),
6+
np.roll(np.tile(np.arange(100), 100), 25)])
7+
df = DataFrame(np.random.randn(10000, 4), index=index)
8+
"""
9+
10+
reshape_unstack_simple = Benchmark('df.unstack(1)', common_setup,
11+
start_date=datetime(2011, 10, 1))
12+
13+
setup = common_setup + """
14+
udf = df.unstack(1)
15+
"""
16+
17+
reshape_stack_simple = Benchmark('udf.stack()', setup,
18+
start_date=datetime(2011, 10, 1))

0 commit comments

Comments
 (0)