Skip to content

Commit 583b4d4

Browse files
author
Tom Augspurger
committed
Merge pull request #6678 from sinhrks/legend_pr
BUG: legend behaves inconsistently when plotting to the same axes
2 parents 92852bd + 78cd15b commit 583b4d4

File tree

4 files changed

+177
-118
lines changed

4 files changed

+177
-118
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ Bug Fixes
390390
group match wasn't renamed to the group name
391391
- Bug in ``DataFrame.to_csv`` where setting `index` to `False` ignored the
392392
`header` kwarg (:issue:`6186`)
393+
- Bug in `DataFrame.plot` and `Series.plot` legend behave inconsistently when plotting to the same axes repeatedly (:issue:`6678`)
393394

394395
pandas 0.13.1
395396
-------------

doc/source/visualization.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ for controlling the look of the plot:
6666
.. ipython:: python
6767
6868
@savefig series_plot_basic2.png
69-
plt.figure(); ts.plot(style='k--', label='Series'); plt.legend()
69+
plt.figure(); ts.plot(style='k--', label='Series');
7070
7171
On DataFrame, ``plot`` is a convenience to plot all of the columns with labels:
7272

@@ -76,7 +76,7 @@ On DataFrame, ``plot`` is a convenience to plot all of the columns with labels:
7676
df = df.cumsum()
7777
7878
@savefig frame_plot_basic.png
79-
plt.figure(); df.plot(); plt.legend(loc='best')
79+
plt.figure(); df.plot();
8080
8181
You may set the ``legend`` argument to ``False`` to hide the legend, which is
8282
shown by default.
@@ -91,7 +91,7 @@ Some other options are available, like plotting each Series on a different axis:
9191
.. ipython:: python
9292
9393
@savefig frame_plot_subplots.png
94-
df.plot(subplots=True, figsize=(6, 6)); plt.legend(loc='best')
94+
df.plot(subplots=True, figsize=(6, 6));
9595
9696
You may pass ``logy`` to get a log-scale Y axis.
9797

pandas/tests/test_graphics.py

+96-19
Original file line numberDiff line numberDiff line change
@@ -490,29 +490,34 @@ def test_subplots(self):
490490
df = DataFrame(np.random.rand(10, 3),
491491
index=list(string.ascii_letters[:10]))
492492

493-
axes = df.plot(subplots=True, sharex=True, legend=True)
493+
for kind in ['bar', 'barh', 'line']:
494+
axes = df.plot(kind=kind, subplots=True, sharex=True, legend=True)
494495

495-
for ax in axes:
496-
self.assertIsNotNone(ax.get_legend())
497-
498-
axes = df.plot(subplots=True, sharex=True)
499-
for ax in axes[:-2]:
500-
[self.assert_(not label.get_visible())
501-
for label in ax.get_xticklabels()]
502-
[self.assert_(label.get_visible())
503-
for label in ax.get_yticklabels()]
496+
for ax, column in zip(axes, df.columns):
497+
self._check_legend_labels(ax, [column])
504498

505-
[self.assert_(label.get_visible())
506-
for label in axes[-1].get_xticklabels()]
507-
[self.assert_(label.get_visible())
508-
for label in axes[-1].get_yticklabels()]
499+
axes = df.plot(kind=kind, subplots=True, sharex=True)
500+
for ax in axes[:-2]:
501+
[self.assert_(not label.get_visible())
502+
for label in ax.get_xticklabels()]
503+
[self.assert_(label.get_visible())
504+
for label in ax.get_yticklabels()]
509505

510-
axes = df.plot(subplots=True, sharex=False)
511-
for ax in axes:
512506
[self.assert_(label.get_visible())
513-
for label in ax.get_xticklabels()]
507+
for label in axes[-1].get_xticklabels()]
514508
[self.assert_(label.get_visible())
515-
for label in ax.get_yticklabels()]
509+
for label in axes[-1].get_yticklabels()]
510+
511+
axes = df.plot(kind=kind, subplots=True, sharex=False)
512+
for ax in axes:
513+
[self.assert_(label.get_visible())
514+
for label in ax.get_xticklabels()]
515+
[self.assert_(label.get_visible())
516+
for label in ax.get_yticklabels()]
517+
518+
axes = df.plot(kind=kind, subplots=True, legend=False)
519+
for ax in axes:
520+
self.assertTrue(ax.get_legend() is None)
516521

517522
@slow
518523
def test_bar_colors(self):
@@ -873,7 +878,7 @@ def test_kde(self):
873878
_check_plot_works(df.plot, kind='kde')
874879
_check_plot_works(df.plot, kind='kde', subplots=True)
875880
ax = df.plot(kind='kde')
876-
self.assertIsNotNone(ax.get_legend())
881+
self._check_legend_labels(ax, df.columns)
877882
axes = df.plot(kind='kde', logy=True, subplots=True)
878883
for ax in axes:
879884
self.assertEqual(ax.get_yscale(), 'log')
@@ -1046,6 +1051,64 @@ def test_plot_int_columns(self):
10461051
df = DataFrame(randn(100, 4)).cumsum()
10471052
_check_plot_works(df.plot, legend=True)
10481053

1054+
def _check_legend_labels(self, ax, labels):
1055+
import pandas.core.common as com
1056+
labels = [com.pprint_thing(l) for l in labels]
1057+
self.assertTrue(ax.get_legend() is not None)
1058+
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
1059+
self.assertEqual(labels, legend_labels)
1060+
1061+
@slow
1062+
def test_df_legend_labels(self):
1063+
kinds = 'line', 'bar', 'barh', 'kde', 'density'
1064+
df = DataFrame(randn(3, 3), columns=['a', 'b', 'c'])
1065+
df2 = DataFrame(randn(3, 3), columns=['d', 'e', 'f'])
1066+
df3 = DataFrame(randn(3, 3), columns=['g', 'h', 'i'])
1067+
df4 = DataFrame(randn(3, 3), columns=['j', 'k', 'l'])
1068+
1069+
for kind in kinds:
1070+
ax = df.plot(kind=kind, legend=True)
1071+
self._check_legend_labels(ax, df.columns)
1072+
1073+
ax = df2.plot(kind=kind, legend=False, ax=ax)
1074+
self._check_legend_labels(ax, df.columns)
1075+
1076+
ax = df3.plot(kind=kind, legend=True, ax=ax)
1077+
self._check_legend_labels(ax, df.columns + df3.columns)
1078+
1079+
ax = df4.plot(kind=kind, legend='reverse', ax=ax)
1080+
expected = list(df.columns + df3.columns) + list(reversed(df4.columns))
1081+
self._check_legend_labels(ax, expected)
1082+
1083+
# Secondary Y
1084+
ax = df.plot(legend=True, secondary_y='b')
1085+
self._check_legend_labels(ax, ['a', 'b (right)', 'c'])
1086+
ax = df2.plot(legend=False, ax=ax)
1087+
self._check_legend_labels(ax, ['a', 'b (right)', 'c'])
1088+
ax = df3.plot(kind='bar', legend=True, secondary_y='h', ax=ax)
1089+
self._check_legend_labels(ax, ['a', 'b (right)', 'c', 'g', 'h (right)', 'i'])
1090+
1091+
# Time Series
1092+
ind = date_range('1/1/2014', periods=3)
1093+
df = DataFrame(randn(3, 3), columns=['a', 'b', 'c'], index=ind)
1094+
df2 = DataFrame(randn(3, 3), columns=['d', 'e', 'f'], index=ind)
1095+
df3 = DataFrame(randn(3, 3), columns=['g', 'h', 'i'], index=ind)
1096+
ax = df.plot(legend=True, secondary_y='b')
1097+
self._check_legend_labels(ax, ['a', 'b (right)', 'c'])
1098+
ax = df2.plot(legend=False, ax=ax)
1099+
self._check_legend_labels(ax, ['a', 'b (right)', 'c'])
1100+
ax = df3.plot(legend=True, ax=ax)
1101+
self._check_legend_labels(ax, ['a', 'b (right)', 'c', 'g', 'h', 'i'])
1102+
1103+
# scatter
1104+
ax = df.plot(kind='scatter', x='a', y='b', label='data1')
1105+
self._check_legend_labels(ax, ['data1'])
1106+
ax = df2.plot(kind='scatter', x='d', y='e', legend=False,
1107+
label='data2', ax=ax)
1108+
self._check_legend_labels(ax, ['data1'])
1109+
ax = df3.plot(kind='scatter', x='g', y='h', label='data3', ax=ax)
1110+
self._check_legend_labels(ax, ['data1', 'data3'])
1111+
10491112
def test_legend_name(self):
10501113
multi = DataFrame(randn(4, 4),
10511114
columns=[np.array(['a', 'a', 'b', 'b']),
@@ -1056,6 +1119,20 @@ def test_legend_name(self):
10561119
leg_title = ax.legend_.get_title()
10571120
self.assertEqual(leg_title.get_text(), 'group,individual')
10581121

1122+
df = DataFrame(randn(5, 5))
1123+
ax = df.plot(legend=True, ax=ax)
1124+
leg_title = ax.legend_.get_title()
1125+
self.assertEqual(leg_title.get_text(), 'group,individual')
1126+
1127+
df.columns.name = 'new'
1128+
ax = df.plot(legend=False, ax=ax)
1129+
leg_title = ax.legend_.get_title()
1130+
self.assertEqual(leg_title.get_text(), 'group,individual')
1131+
1132+
ax = df.plot(legend=True, ax=ax)
1133+
leg_title = ax.legend_.get_title()
1134+
self.assertEqual(leg_title.get_text(), 'new')
1135+
10591136
def _check_plot_fails(self, f, *args, **kwargs):
10601137
with tm.assertRaises(Exception):
10611138
f(*args, **kwargs)

0 commit comments

Comments
 (0)