Skip to content

Commit 484a8b3

Browse files
Dieter Vandenbusschewesm
Dieter Vandenbussche
authored andcommitted
Add kind argument to DataFrame.plot
1 parent 485b68b commit 484a8b3

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

pandas/core/frame.py

+46-16
Original file line numberDiff line numberDiff line change
@@ -2961,8 +2961,8 @@ def clip_lower(self, threshold):
29612961
#----------------------------------------------------------------------
29622962
# Plotting
29632963

2964-
def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
2965-
figsize=None, grid=True, legend=True, ax=None, **kwds):
2964+
def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index=True,
2965+
figsize=None, grid=True, legend=True, rot=30, ax=None, **kwds):
29662966
"""
29672967
Make line plot of DataFrame's series with the index on the x-axis using
29682968
matplotlib / pylab.
@@ -2998,22 +2998,49 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
29982998
else:
29992999
fig = ax.get_figure()
30003000

3001-
if use_index:
3002-
x = self.index
3003-
else:
3004-
x = range(len(self))
3001+
if kind == 'line':
3002+
if use_index:
3003+
x = self.index
3004+
else:
3005+
x = range(len(self))
3006+
3007+
for i, col in enumerate(_try_sort(self.columns)):
3008+
empty = self[col].count() == 0
3009+
y = self[col].values if not empty else np.zeros(x.shape)
3010+
if subplots:
3011+
ax = axes[i]
3012+
ax.plot(x, y, 'k', label=str(col), **kwds)
3013+
ax.legend(loc='best')
3014+
else:
3015+
ax.plot(x, y, label=str(col), **kwds)
3016+
3017+
ax.grid(grid)
3018+
elif kind == 'bar':
3019+
N = len(self)
3020+
M = len(self.columns)
3021+
xinds = np.arange(N) + 0.25
3022+
colors = ['red', 'green', 'blue', 'yellow', 'black']
3023+
rects = []
3024+
labels = []
3025+
for i, col in enumerate(_try_sort(self.columns)):
3026+
empty = self[col].count() == 0
3027+
y = self[col].values if not empty else np.zeros(x.shape)
3028+
if subplots:
3029+
ax = axes[i]
3030+
ax.bar(xinds, y, 0.5,
3031+
bottom=np.zeros(N), linewidth=1, **kwds)
3032+
ax.set_title(col)
3033+
else:
3034+
rects.append(ax.bar(xinds+i*0.5/M,y,0.5/M,bottom=np.zeros(N),color=colors[i % len(colors)], **kwds))
3035+
labels.append(col)
30053036

3006-
for i, col in enumerate(_try_sort(self.columns)):
3007-
empty = self[col].count() == 0
3008-
y = self[col].values if not empty else np.zeros(x.shape)
3009-
if subplots:
3010-
ax = axes[i]
3011-
ax.plot(x, y, 'k', label=str(col), **kwds)
3012-
ax.legend(loc='best')
3037+
if N < 10:
3038+
fontsize = 12
30133039
else:
3014-
ax.plot(x, y, label=str(col), **kwds)
3040+
fontsize = 10
30153041

3016-
ax.grid(grid)
3042+
ax.set_xticks(xinds + 0.25)
3043+
ax.set_xticklabels(self.index, rotation=rot, fontsize=fontsize)
30173044

30183045
# try to make things prettier
30193046
try:
@@ -3022,7 +3049,10 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
30223049
pass
30233050

30243051
if legend and not subplots:
3025-
ax.legend(loc='best')
3052+
if kind == 'line':
3053+
ax.legend(loc='best')
3054+
else:
3055+
ax.legend([r[0] for r in rects],labels,loc='best')
30263056

30273057
plt.draw_if_interactive()
30283058

0 commit comments

Comments
 (0)