Skip to content

Commit a022d45

Browse files
committed
ENH: DataFrame plot improvements, make pivot_table more flexible per #941
1 parent 97f55b9 commit a022d45

File tree

5 files changed

+67
-25
lines changed

5 files changed

+67
-25
lines changed

pandas/core/frame.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -3845,8 +3845,9 @@ def boxplot(self, column=None, by=None, ax=None, fontsize=None,
38453845
return ax
38463846

38473847
def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
3848-
figsize=None, grid=True, legend=True, rot=30, ax=None,
3849-
kind='line', sort_columns=True, **kwds):
3848+
figsize=None, grid=True, legend=True, rot=30, ax=None, title=None,
3849+
xlim=None, ylim=None, xticks=None, yticks=None, kind='line',
3850+
sort_columns=True, fontsize=None, **kwds):
38503851
"""
38513852
Make line plot of DataFrame's series with the index on the x-axis using
38523853
matplotlib / pylab.
@@ -3934,22 +3935,37 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
39343935
ax_.set_xticklabels(xticklabels, rotation=rot)
39353936
elif kind == 'bar':
39363937
self._bar_plot(axes, subplots=subplots, grid=grid, rot=rot,
3937-
legend=legend)
3938+
legend=legend, ax=ax, fontsize=fontsize)
39383939

3939-
if not subplots or (subplots and sharex):
3940+
if self.index.is_all_dates and not subplots or (subplots and sharex):
39403941
try:
39413942
fig.autofmt_xdate()
39423943
except Exception: # pragma: no cover
39433944
pass
39443945

3946+
if yticks is not None:
3947+
ax.set_yticks(yticks)
3948+
3949+
if xticks is not None:
3950+
ax.set_xticks(xticks)
3951+
3952+
if ylim is not None:
3953+
ax.set_ylim(ylim)
3954+
3955+
if xlim is not None:
3956+
ax.set_xlim(xlim)
3957+
3958+
if title and not subplots:
3959+
ax.set_title(title)
3960+
39453961
plt.draw_if_interactive()
39463962
if subplots:
39473963
return axes
39483964
else:
39493965
return ax
39503966

39513967
def _bar_plot(self, axes, subplots=False, use_index=True, grid=True,
3952-
rot=30, legend=True, **kwds):
3968+
rot=30, legend=True, ax=None, fontsize=None, **kwds):
39533969
import pandas.tools.plotting as gfx
39543970

39553971
N, K = self.shape
@@ -3958,7 +3974,7 @@ def _bar_plot(self, axes, subplots=False, use_index=True, grid=True,
39583974
rects = []
39593975
labels = []
39603976

3961-
if not subplots:
3977+
if not subplots and ax is None:
39623978
ax = axes[0]
39633979

39643980
for i, col in enumerate(self.columns):
@@ -3970,17 +3986,20 @@ def _bar_plot(self, axes, subplots=False, use_index=True, grid=True,
39703986
bottom=np.zeros(N), linewidth=1, **kwds)
39713987
ax.set_title(col)
39723988
else:
3973-
rects.append(ax.bar(xinds + i * 0.5 / K, y, 0.5 / K,
3989+
rects.append(ax.bar(xinds + i * 0.75 / K, y, 0.75 / K,
39743990
bottom=np.zeros(N), label=col,
39753991
color=colors[i % len(colors)], **kwds))
39763992
labels.append(col)
39773993

3978-
if N < 10:
3979-
fontsize = 12
3980-
else:
3981-
fontsize = 10
3994+
if fontsize is None:
3995+
if N < 10:
3996+
fontsize = 12
3997+
else:
3998+
fontsize = 10
3999+
4000+
ax.set_xlim([xinds[0] - 1, xinds[-1] + 1])
39824001

3983-
ax.set_xticks(xinds + 0.25)
4002+
ax.set_xticks(xinds + 0.375)
39844003
ax.set_xticklabels([gfx._stringify(key) for key in self.index],
39854004
rotation=rot,
39864005
fontsize=fontsize)

pandas/core/groupby.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _python_apply_general(self, func, *args, **kwargs):
342342

343343
not_indexed_same = False
344344
for key, group in self:
345-
group.name = key
345+
object.__setattr__(group, 'name', key)
346346

347347
# group might be modified
348348
group_axes = _get_axes(group)
@@ -1003,7 +1003,7 @@ def transform(self, func, *args, **kwargs):
10031003
result = self.obj.copy()
10041004

10051005
for name, group in self:
1006-
group.name = name
1006+
object.__setattr__(group, 'name', name)
10071007
res = func(group, *args, **kwargs)
10081008
indexer = self.obj.index.get_indexer(group.index)
10091009
np.put(result, indexer, res)
@@ -1363,7 +1363,7 @@ def transform(self, func, *args, **kwargs):
13631363

13641364
obj = self._obj_with_exclusions
13651365
for name, group in self:
1366-
group.name = name
1366+
object.__setattr__(group, 'name', name)
13671367

13681368
try:
13691369
wrapper = lambda x: func(x, *args, **kwargs)

pandas/tools/pivot.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas.tools.merge import concat
55
import pandas.core.common as com
66
import numpy as np
7+
import types
78

89
def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
910
fill_value=None, margins=False):
@@ -16,10 +17,10 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
1617
----------
1718
data : DataFrame
1819
values : column to aggregate, optional
19-
rows : list
20-
Columns to group on the x-axis of the pivot table
21-
cols : list
22-
Columns to group on the x-axis of the pivot table
20+
rows : list of column names or arrays to group on
21+
Keys to group on the x-axis of the pivot table
22+
cols : list of column names or arrays to group on
23+
Keys to group on the x-axis of the pivot table
2324
aggfunc : function, default numpy.mean, or list of functions
2425
If list of functions passed, the resulting pivot table will have
2526
hierarchical columns whose top level are the function names (inferred
@@ -83,14 +84,23 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
8384
values = list(data.columns.drop(keys))
8485

8586
if values_passed:
86-
data = data[keys + values]
87+
to_filter = []
88+
for x in keys + values:
89+
try:
90+
if x in data:
91+
to_filter.append(x)
92+
except TypeError:
93+
pass
94+
if len(to_filter) < len(data.columns):
95+
data = data[to_filter]
8796

8897
grouped = data.groupby(keys)
8998
agged = grouped.agg(aggfunc)
9099

91100
table = agged
92-
for k in cols:
93-
table = table.unstack(level=k)
101+
for i in range(len(cols)):
102+
name = table.index.names[len(rows)]
103+
table = table.unstack(name)
94104

95105
if fill_value is not None:
96106
table = table.fillna(value=fill_value)
@@ -183,7 +193,8 @@ def _all_key(key):
183193
def _convert_by(by):
184194
if by is None:
185195
by = []
186-
elif np.isscalar(by):
196+
elif (np.isscalar(by) or isinstance(by, np.ndarray)
197+
or hasattr(by, '__call__')):
187198
by = [by]
188199
else:
189200
by = list(by)

pandas/tools/tests/test_pivot.py

+12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def test_pivot_table(self):
4747
expected = self.data.groupby(rows + [cols])['D'].agg(np.mean).unstack()
4848
tm.assert_frame_equal(table, expected)
4949

50+
def test_pass_array(self):
51+
result = self.data.pivot_table('D', rows=self.data.A, cols=self.data.C)
52+
expected = self.data.pivot_table('D', rows='A', cols='C')
53+
tm.assert_frame_equal(result, expected)
54+
55+
def test_pass_function(self):
56+
result = self.data.pivot_table('D', rows=lambda x: x // 5,
57+
cols=self.data.C)
58+
expected = self.data.pivot_table('D', rows=self.data.index // 5,
59+
cols='C')
60+
tm.assert_frame_equal(result, expected)
61+
5062
def test_pivot_table_multiple(self):
5163
rows = ['A', 'B']
5264
cols= 'C'

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@
166166

167167
MAJOR = 0
168168
MINOR = 7
169-
MICRO = 2
170-
ISRELEASED = True
169+
MICRO = 3
170+
ISRELEASED = False
171171
VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO)
172172
QUALIFIER = ''
173173

0 commit comments

Comments
 (0)