Skip to content

Commit 09e93c3

Browse files
Chang Shewesm
Chang She
authored andcommitted
BUG: custom colors for bar chart #1540
1 parent 5315aee commit 09e93c3

File tree

3 files changed

+42
-21
lines changed

3 files changed

+42
-21
lines changed

pandas/tests/test_graphics.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def setUp(self):
3939

4040
@slow
4141
def test_plot(self):
42+
import matplotlib.pyplot as plt
43+
import matplotlib.colors as colors
4244
_check_plot_works(self.ts.plot, label='foo')
4345
_check_plot_works(self.ts.plot, use_index=False)
4446
_check_plot_works(self.ts.plot, rot=0)
@@ -53,6 +55,34 @@ def test_plot(self):
5355

5456
Series(np.random.randn(10)).plot(kind='bar',color='black')
5557

58+
default_colors = 'brgyk'
59+
custom_colors = 'rgcby'
60+
61+
plt.close('all')
62+
df = DataFrame(np.random.randn(5, 5))
63+
ax = df.plot(kind='bar')
64+
65+
rects = ax.patches
66+
67+
conv = colors.colorConverter
68+
for i, rect in enumerate(rects[:5]):
69+
xp = conv.to_rgba(default_colors[i])
70+
rs = rect.get_facecolor()
71+
self.assert_(xp, rs)
72+
73+
plt.close('all')
74+
75+
ax = df.plot(kind='bar', color=custom_colors)
76+
77+
rects = ax.patches
78+
79+
conv = colors.colorConverter
80+
for i, rect in enumerate(rects[:5]):
81+
xp = conv.to_rgba(custom_colors[i])
82+
rs = rect.get_facecolor()
83+
self.assert_(xp, rs)
84+
85+
5686
@slow
5787
def test_irregular_datetime(self):
5888
rng = date_range('1/1/2000', '3/1/2000')

pandas/tools/plotting.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,9 +498,8 @@ def plt(self):
498498

499499
def _get_xticks(self, convert_period=False):
500500
index = self.data.index
501-
is_datetype = (index.inferred_type in ('datetime', 'date',
502-
'datetime64')
503-
or lib.is_time_array(index))
501+
is_datetype = index.inferred_type in ('datetime', 'date',
502+
'datetime64', 'time')
504503

505504
if self.use_index:
506505
if convert_period and isinstance(index, PeriodIndex):
@@ -515,7 +514,6 @@ def _get_xticks(self, convert_period=False):
515514
"""
516515
x = index._mpl_repr()
517516
else:
518-
foo
519517
self._need_to_set_index = True
520518
x = range(len(index))
521519
else:
@@ -750,7 +748,7 @@ def f(ax, x, y, w, start=None, **kwds):
750748
return f
751749

752750
def _make_plot(self):
753-
colors = 'brgyk'
751+
colors = self.kwds.get('color', 'brgyk')
754752
rects = []
755753
labels = []
756754

@@ -765,8 +763,7 @@ def _make_plot(self):
765763
for i, (label, y) in enumerate(self._iter_data()):
766764

767765
kwds = self.kwds.copy()
768-
if 'color' not in kwds:
769-
kwds['color'] = colors[i % len(colors)]
766+
kwds['color'] = colors[i % len(colors)]
770767

771768
if self.subplots:
772769
ax, _ = self._get_ax_and_style(i) #self.axes[i]

pandas/tseries/tests/test_plotting.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -587,18 +587,6 @@ def test_irreg_dtypes(self):
587587
df = DataFrame(np.random.randn(len(idx), 3), idx)
588588
_check_plot_works(df.plot)
589589

590-
#time
591-
plt.close('all')
592-
inc = Series(np.random.randint(1, 15, 3)).cumsum().values
593-
idx = [time(1, 1, i) for i in inc]
594-
df = DataFrame(np.random.randn(len(idx), 3), idx)
595-
ax = df.plot()
596-
ticks = ax.get_xticks()
597-
labels = ax.get_xticklabels()
598-
td = dict(zip(ticks, labels))
599-
for i in range(3):
600-
self.assert_(td[i].get_text() == str(idx[i]))
601-
602590
@slow
603591
def test_time(self):
604592
import matplotlib.pyplot as plt
@@ -618,7 +606,10 @@ def test_time(self):
618606
for t, l in zip(ticks, labels):
619607
m, s = divmod(int(t), 60)
620608
h, m = divmod(m, 60)
621-
self.assert_(time(h, m, s).strftime('%H:%M:%S') == t.get_text())
609+
xp = l.get_text()
610+
if len(xp) > 0:
611+
rs = time(h, m, s).strftime('%H:%M:%S')
612+
self.assert_(xp, rs)
622613

623614
# change xlim
624615
ax.set_xlim('1:30', '5:00')
@@ -629,7 +620,10 @@ def test_time(self):
629620
for t, l in zip(ticks, labels):
630621
m, s = divmod(int(t), 60)
631622
h, m = divmod(m, 60)
632-
self.assert_(time(h, m, s).strftime('%H:%M:%S') == t.get_text())
623+
xp = l.get_text()
624+
if len(xp) > 0:
625+
rs = time(h, m, s).strftime('%H:%M:%S')
626+
self.assert_(xp, rs)
633627

634628
PNG_PATH = 'tmp.png'
635629
def _check_plot_works(f, freq=None, series=None, *args, **kwargs):

0 commit comments

Comments
 (0)