Skip to content

Commit 0518b54

Browse files
committed
ENH: refactoring/mpl tweaking post PR #348
1 parent 484a8b3 commit 0518b54

File tree

2 files changed

+64
-32
lines changed

2 files changed

+64
-32
lines changed

pandas/core/frame.py

+50-32
Original file line numberDiff line numberDiff line change
@@ -2961,8 +2961,9 @@ def clip_lower(self, threshold):
29612961
#----------------------------------------------------------------------
29622962
# Plotting
29632963

2964-
def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index=True,
2965-
figsize=None, grid=True, legend=True, rot=30, ax=None, **kwds):
2964+
def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
2965+
figsize=None, grid=True, legend=True, rot=30, ax=None,
2966+
kind='line', **kwds):
29662967
"""
29672968
Make line plot of DataFrame's series with the index on the x-axis using
29682969
matplotlib / pylab.
@@ -2977,6 +2978,7 @@ def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index
29772978
In case subplots=True, share y axis
29782979
use_index : boolean, default True
29792980
Use index as ticks for x axis
2981+
kind : {'line', 'bar'}
29802982
kwds : keywords
29812983
Options to pass to Axis.plot
29822984
@@ -2995,6 +2997,7 @@ def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index
29952997
if ax is None:
29962998
fig = plt.figure(figsize=figsize)
29972999
ax = fig.add_subplot(111)
3000+
axes = [ax]
29983001
else:
29993002
fig = ax.get_figure()
30003003

@@ -3015,46 +3018,61 @@ def plot(self, kind='line', subplots=False, sharex=True, sharey=False, use_index
30153018
ax.plot(x, y, label=str(col), **kwds)
30163019

30173020
ax.grid(grid)
3018-
elif kind == 'bar':
3019-
N = len(self)
3020-
M = len(self.columns)
3021-
xinds = np.arange(N) + 0.25
3022-
colors = ['red', 'green', 'blue', 'yellow', 'black']
3023-
rects = []
3024-
labels = []
3025-
for i, col in enumerate(_try_sort(self.columns)):
3026-
empty = self[col].count() == 0
3027-
y = self[col].values if not empty else np.zeros(x.shape)
3028-
if subplots:
3029-
ax = axes[i]
3030-
ax.bar(xinds, y, 0.5,
3031-
bottom=np.zeros(N), linewidth=1, **kwds)
3032-
ax.set_title(col)
3033-
else:
3034-
rects.append(ax.bar(xinds+i*0.5/M,y,0.5/M,bottom=np.zeros(N),color=colors[i % len(colors)], **kwds))
3035-
labels.append(col)
30363021

3037-
if N < 10:
3038-
fontsize = 12
3039-
else:
3040-
fontsize = 10
3041-
3042-
ax.set_xticks(xinds + 0.25)
3043-
ax.set_xticklabels(self.index, rotation=rot, fontsize=fontsize)
3022+
if legend and not subplots:
3023+
ax.legend(loc='best')
3024+
elif kind == 'bar':
3025+
self._bar_plot(axes, subplots=subplots, grid=grid, rot=rot,
3026+
legend=legend)
30443027

30453028
# try to make things prettier
30463029
try:
30473030
fig.autofmt_xdate()
30483031
except Exception: # pragma: no cover
30493032
pass
30503033

3051-
if legend and not subplots:
3052-
if kind == 'line':
3053-
ax.legend(loc='best')
3034+
plt.draw_if_interactive()
3035+
3036+
def _bar_plot(self, axes, subplots=False, use_index=True, grid=True,
3037+
rot=30, legend=True, **kwds):
3038+
N, K = self.shape
3039+
xinds = np.arange(N) + 0.25
3040+
colors = 'rgbyk'
3041+
rects = []
3042+
labels = []
3043+
3044+
if not subplots:
3045+
ax = axes[0]
3046+
3047+
for i, col in enumerate(_try_sort(self.columns)):
3048+
empty = self[col].count() == 0
3049+
y = self[col].values if not empty else np.zeros(len(self))
3050+
if subplots:
3051+
ax = axes[i]
3052+
ax.bar(xinds, y, 0.5,
3053+
bottom=np.zeros(N), linewidth=1, **kwds)
3054+
ax.set_title(col)
30543055
else:
3055-
ax.legend([r[0] for r in rects],labels,loc='best')
3056+
rects.append(ax.bar(xinds + i * 0.5/K, y, 0.5/K,
3057+
bottom=np.zeros(N),
3058+
color=colors[i % len(colors)], **kwds))
3059+
labels.append(col)
30563060

3057-
plt.draw_if_interactive()
3061+
if N < 10:
3062+
fontsize = 12
3063+
else:
3064+
fontsize = 10
3065+
3066+
ax.set_xticks(xinds + 0.25)
3067+
ax.set_xticklabels(self.index, rotation=rot, fontsize=fontsize)
3068+
3069+
if legend and not subplots:
3070+
fig = ax.get_figure()
3071+
fig.legend([r[0] for r in rects], labels, loc='upper center',
3072+
fancybox=True, ncol=6, mode='expand')
3073+
3074+
import matplotlib.pyplot as plt
3075+
plt.subplots_adjust(top=0.8)
30583076

30593077
def hist(self, grid=True, **kwds):
30603078
"""

pandas/tests/test_graphics.py

+14
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ def test_plot(self):
6161
_check_plot_works(df.plot, subplots=True)
6262
_check_plot_works(df.plot, subplots=True, use_index=False)
6363

64+
def test_plot_bar(self):
65+
df = DataFrame(np.random.randn(6, 4),
66+
index=['a', 'b', 'c', 'd', 'e', 'f'],
67+
columns=['one', 'two', 'three', 'four'])
68+
69+
_check_plot_works(df.plot, kind='bar')
70+
_check_plot_works(df.plot, kind='bar', legend=False)
71+
_check_plot_works(df.plot, kind='bar', subplots=True)
72+
73+
df = DataFrame(np.random.randn(6, 15),
74+
index=['a', 'b', 'c', 'd', 'e', 'f'],
75+
columns=range(15))
76+
_check_plot_works(df.plot, kind='bar')
77+
6478
def test_hist(self):
6579
df = DataFrame(np.random.randn(100, 4))
6680
_check_plot_works(df.hist)

0 commit comments

Comments
 (0)