Skip to content

Commit 2d63a71

Browse files
committed
ENH add drop_na argument to pivot_table
1 parent 030f613 commit 2d63a71

File tree

5 files changed

+95
-6
lines changed

5 files changed

+95
-6
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pandas 0.12
7272
- support python3 (via ``PyTables 3.0.0``) (:issue:`3750`)
7373
- Add modulo operator to Series, DataFrame
7474
- Add ``date`` method to DatetimeIndex
75+
- Add ``dropna`` argument to pivot_table (:issue: `3820`)
7576
- Simplified the API and added a describe method to Categorical
7677
- ``melt`` now accepts the optional parameters ``var_name`` and ``value_name``
7778
to specify custom column names of the returned DataFrame (:issue:`3649`),

pandas/tools/pivot.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from pandas.core.index import MultiIndex
55
from pandas.core.reshape import _unstack_multiple
66
from pandas.tools.merge import concat
7+
from pandas.tools.util import cartesian_product
78
import pandas.core.common as com
89
import numpy as np
910

1011

1112
def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
12-
fill_value=None, margins=False):
13+
fill_value=None, margins=False, dropna=True):
1314
"""
1415
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
1516
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
@@ -31,6 +32,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
3132
Value to replace missing values with
3233
margins : boolean, default False
3334
Add all row / columns (e.g. for subtotal / grand totals)
35+
dropna : boolean, default True
36+
Do not include columns whose entries are all NaN
3437
3538
Examples
3639
--------
@@ -105,6 +108,19 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
105108
for i in range(len(rows), len(keys))]
106109
table = agged.unstack(to_unstack)
107110

111+
if not dropna:
112+
try:
113+
m = MultiIndex.from_arrays(cartesian_product(table.index.levels))
114+
table = table.reindex_axis(m, axis=0)
115+
except AttributeError:
116+
pass # it's a single level
117+
118+
try:
119+
m = MultiIndex.from_arrays(cartesian_product(table.columns.levels))
120+
table = table.reindex_axis(m, axis=1)
121+
except AttributeError:
122+
pass # it's a single level or a series
123+
108124
if isinstance(table, DataFrame):
109125
if isinstance(table.columns, MultiIndex):
110126
table = table.sortlevel(axis=1)
@@ -216,7 +232,7 @@ def _convert_by(by):
216232

217233

218234
def crosstab(rows, cols, values=None, rownames=None, colnames=None,
219-
aggfunc=None, margins=False):
235+
aggfunc=None, margins=False, dropna=True):
220236
"""
221237
Compute a simple cross-tabulation of two (or more) factors. By default
222238
computes a frequency table of the factors unless an array of values and an
@@ -238,6 +254,8 @@ def crosstab(rows, cols, values=None, rownames=None, colnames=None,
238254
If passed, must match number of column arrays passed
239255
margins : boolean, default False
240256
Add row/column margins (subtotals)
257+
dropna : boolean, default True
258+
Do not include columns whose entries are all NaN
241259
242260
Notes
243261
-----
@@ -281,13 +299,13 @@ def crosstab(rows, cols, values=None, rownames=None, colnames=None,
281299
df = DataFrame(data)
282300
df['__dummy__'] = 0
283301
table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
284-
aggfunc=len, margins=margins)
302+
aggfunc=len, margins=margins, dropna=dropna)
285303
return table.fillna(0).astype(np.int64)
286304
else:
287305
data['__dummy__'] = values
288306
df = DataFrame(data)
289307
table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
290-
aggfunc=aggfunc, margins=margins)
308+
aggfunc=aggfunc, margins=margins, dropna=dropna)
291309
return table
292310

293311

pandas/tools/tests/test_pivot.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import unittest
22

33
import numpy as np
4+
from numpy.testing import assert_equal
45

5-
from pandas import DataFrame, Series, Index
6+
from pandas import DataFrame, Series, Index, MultiIndex
67
from pandas.tools.merge import concat
78
from pandas.tools.pivot import pivot_table, crosstab
89
import pandas.util.testing as tm
@@ -62,6 +63,22 @@ def test_pivot_table_nocols(self):
6263
xp = df.pivot_table(rows='cols', aggfunc={'values': 'mean'}).T
6364
tm.assert_frame_equal(rs, xp)
6465

66+
def test_pivot_table_dropna(self):
67+
df = DataFrame({'amount': {0: 60000, 1: 100000, 2: 50000, 3: 30000},
68+
'customer': {0: 'A', 1: 'A', 2: 'B', 3: 'C'},
69+
'month': {0: 201307, 1: 201309, 2: 201308, 3: 201310},
70+
'product': {0: 'a', 1: 'b', 2: 'c', 3: 'd'},
71+
'quantity': {0: 2000000, 1: 500000, 2: 1000000, 3: 1000000}})
72+
pv_col = df.pivot_table('quantity', 'month', ['customer', 'product'], dropna=False)
73+
pv_ind = df.pivot_table('quantity', ['customer', 'product'], 'month', dropna=False)
74+
75+
m = MultiIndex.from_tuples([(u'A', u'a'), (u'A', u'b'), (u'A', u'c'), (u'A', u'd'),
76+
(u'B', u'a'), (u'B', u'b'), (u'B', u'c'), (u'B', u'd'),
77+
(u'C', u'a'), (u'C', u'b'), (u'C', u'c'), (u'C', u'd')])
78+
79+
assert_equal(pv_col.columns.values, m.values)
80+
assert_equal(pv_ind.index.values, m.values)
81+
6582

6683
def test_pass_array(self):
6784
result = self.data.pivot_table('D', rows=self.data.A, cols=self.data.C)
@@ -374,6 +391,16 @@ def test_crosstab_pass_values(self):
374391
aggfunc=np.sum)
375392
tm.assert_frame_equal(table, expected)
376393

394+
def test_crosstab_dropna(self):
395+
# GH 3820
396+
a = np.array(['foo', 'foo', 'foo', 'bar', 'bar', 'foo', 'foo'], dtype=object)
397+
b = np.array(['one', 'one', 'two', 'one', 'two', 'two', 'two'], dtype=object)
398+
c = np.array(['dull', 'dull', 'dull', 'dull', 'dull', 'shiny', 'shiny'], dtype=object)
399+
res = crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'], dropna=False)
400+
m = MultiIndex.from_tuples([('one', 'dull'), ('one', 'shiny'),
401+
('two', 'dull'), ('two', 'shiny')])
402+
assert_equal(res.columns.values, m.values)
403+
377404
if __name__ == '__main__':
378405
import nose
379406
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],

pandas/tools/tests/test_util.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
import nose
3+
import unittest
4+
5+
import numpy as np
6+
from numpy.testing import assert_equal
7+
8+
from pandas.tools.util import cartesian_product
9+
10+
class TestCartesianProduct(unittest.TestCase):
11+
12+
def test_simple(self):
13+
x, y = list('ABC'), [1, 22]
14+
result = cartesian_product([x, y])
15+
expected = [np.array(['A', 'A', 'B', 'B', 'C', 'C']),
16+
np.array([ 1, 22, 1, 22, 1, 22])]
17+
assert_equal(result, expected)
18+
19+
if __name__ == '__main__':
20+
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
21+
exit=False)

pandas/tools/util.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
from pandas.core.index import Index
2+
import numpy as np
23

34
def match(needles, haystack):
45
haystack = Index(haystack)
56
needles = Index(needles)
6-
return haystack.get_indexer(needles)
7+
return haystack.get_indexer(needles)
8+
9+
def cartesian_product(X):
10+
'''
11+
Numpy version of itertools.product or pandas.util.compat.product.
12+
Sometimes faster (for large inputs)...
13+
14+
Examples
15+
--------
16+
>>> cartesian_product([list('ABC'), [1, 2]])
17+
[array(['A', 'A', 'B', 'B', 'C', 'C'], dtype='|S1'),
18+
array([1, 2, 1, 2, 1, 2])]
19+
20+
'''
21+
lenX = map(len, X)
22+
cumprodX = np.cumproduct(lenX)
23+
a = np.insert(cumprodX, 0, 1)
24+
b = a[-1] / a[1:]
25+
return [np.tile(np.repeat(x, b[i]),
26+
np.product(a[i]))
27+
for i, x in enumerate(X)]
28+

0 commit comments

Comments
 (0)