Skip to content

Commit 8edebb2

Browse files
committed
ENH: Allow df plotting with style by some columns.
DataFrame.plot(style=<[..] or {}>) no longer requires style to hold a value for each column of the DataFrame.
1 parent be94b91 commit 8edebb2

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

pandas/tests/test_graphics.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,12 @@ def test_plot(self):
199199
(u'\u03b3', 5),
200200
(u'\u03b4', 6),
201201
(u'\u03b4', 7)], names=['i0', 'i1'])
202-
columns = pandas.MultiIndex.from_tuples([('bar', u'\u0394'),
202+
columns = MultiIndex.from_tuples([('bar', u'\u0394'),
203203
('bar', u'\u0395')], names=['c0', 'c1'])
204-
df = pandas.DataFrame(np.random.randint(0, 10, (8, 2)),
205-
columns=columns,
206-
index=index)
207-
df.plot(title=u'\u03A3')
204+
df = DataFrame(np.random.randint(0, 10, (8, 2)),
205+
columns=columns,
206+
index=index)
207+
_check_plot_works(df.plot, title=u'\u03A3')
208208

209209
@slow
210210
def test_plot_xy(self):
@@ -429,22 +429,17 @@ def _check_plot_fails(self, f, *args, **kwargs):
429429
def test_style_by_column(self):
430430
import matplotlib.pyplot as plt
431431
fig = plt.gcf()
432-
fig.clf()
433-
fig.add_subplot(111)
434432

435433
df = DataFrame(np.random.randn(100, 3))
436-
markers = {0: '^', 1: '+', 2: 'o'}
437-
ax = df.plot(style=markers)
438-
for i, l in enumerate(ax.get_lines()):
439-
self.assertEqual(l.get_marker(), markers[i])
440-
441-
fig.clf()
442-
fig.add_subplot(111)
443-
df = DataFrame(np.random.randn(100, 3))
444-
markers = ['^', '+', 'o']
445-
ax = df.plot(style=markers)
446-
for i, l in enumerate(ax.get_lines()):
447-
self.assertEqual(l.get_marker(), markers[i])
434+
for markers in [{0: '^', 1: '+', 2: 'o'},
435+
{0: '^', 1: '+'},
436+
['^', '+', 'o'],
437+
['^', '+']]:
438+
fig.clf()
439+
fig.add_subplot(111)
440+
ax = df.plot(style=markers)
441+
for i, l in enumerate(ax.get_lines()[:len(markers)]):
442+
self.assertEqual(l.get_marker(), markers[i])
448443

449444
class TestDataFrameGroupByPlots(unittest.TestCase):
450445

pandas/tools/plotting.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -784,9 +784,12 @@ def _get_style(self, i, col_name):
784784

785785
if self.style is not None:
786786
if isinstance(self.style, list):
787-
style = self.style[i]
787+
try:
788+
style = self.style[i]
789+
except IndexError:
790+
pass
788791
elif isinstance(self.style, dict):
789-
style = self.style[col_name]
792+
style = self.style.get(col_name, style)
790793
else:
791794
style = self.style
792795

0 commit comments

Comments
 (0)