Skip to content

Commit bba56bf

Browse files
committed
BUG: pivot_table always returns a DataFrame
Before this commit, if * `values` is not list like * `columns` is `None` * `aggfunc` is not instance of `list` `pivot_table` returns a `Series`. This commit adds checking for `columns.nlevels` is greater than 1 to prevent from casting `table` to a `Series`. This will fix #4386.
1 parent 0bf4532 commit bba56bf

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,4 @@ Bug Fixes
771771
- Bug in ``pd.melt()`` where passing a tuple value for ``value_vars`` caused a ``TypeError`` (:issue:`15348`)
772772
- Bug in ``.eval()`` which caused multiline evals to fail with local variables not on the first line (:issue:`15342`)
773773
- Bug in ``pd.read_msgpack`` which did not allow to load dataframe with an index of type ``CategoricalIndex`` (:issue:`15487`)
774+
- Bug in ``pivot_table`` returns ``Series`` in specific circumstance (:issue:`4386`)

pandas/tests/tools/test_pivot.py

+38
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,44 @@ def test_categorical_pivot_index_ordering(self):
939939
columns=expected_columns)
940940
tm.assert_frame_equal(result, expected)
941941

942+
def test_pivot_table_not_series(self):
943+
# GH 4386
944+
# pivot_table always returns a DataFrame
945+
# when values is not list like and columns is None
946+
# and aggfunc is not instance of list
947+
df = DataFrame({'col1': [3, 4, 5],
948+
'col2': ['C', 'D', 'E'],
949+
'col3': [1, 3, 9]})
950+
951+
result = df.pivot_table('col1', index=['col3', 'col2'], aggfunc=np.sum)
952+
m = MultiIndex.from_arrays([[1, 3, 9],
953+
['C', 'D', 'E']],
954+
names=['col3', 'col2'])
955+
expected = DataFrame([3, 4, 5],
956+
index=m, columns=['col1'])
957+
958+
tm.assert_frame_equal(result, expected)
959+
960+
result = df.pivot_table(
961+
'col1', index='col3', columns='col2', aggfunc=np.sum
962+
)
963+
expected = DataFrame([[3, np.NaN, np.NaN],
964+
[np.NaN, 4, np.NaN],
965+
[np.NaN, np.NaN, 5]],
966+
index=Index([1, 3, 9], name='col3'),
967+
columns=Index(['C', 'D', 'E'], name='col2'))
968+
969+
tm.assert_frame_equal(result, expected)
970+
971+
result = df.pivot_table('col1', index='col3', aggfunc=[np.sum])
972+
m = MultiIndex.from_arrays([['sum'],
973+
['col1']])
974+
expected = DataFrame([3, 4, 5],
975+
index=Index([1, 3, 9], name='col3'),
976+
columns=m)
977+
978+
tm.assert_frame_equal(result, expected)
979+
942980

943981
class TestCrosstab(tm.TestCase):
944982

pandas/tools/pivot.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
169169
margins_name=margins_name)
170170

171171
# discard the top level
172-
if values_passed and not values_multi and not table.empty:
172+
if values_passed and not values_multi and not table.empty and \
173+
(table.columns.nlevels > 1):
173174
table = table[values[0]]
174175

175176
if len(index) == 0 and len(columns) > 0:

0 commit comments

Comments
 (0)