Skip to content

Commit f63de79

Browse files
committed
ENH/BUG: color cannot be applied to line subplots
1 parent f0ac930 commit f63de79

File tree

2 files changed

+113
-24
lines changed

2 files changed

+113
-24
lines changed

pandas/tests/test_graphics.py

+87
Original file line numberDiff line numberDiff line change
@@ -2800,6 +2800,48 @@ def test_line_colors(self):
28002800
ax = df.ix[:, [0]].plot(color='DodgerBlue')
28012801
self._check_colors(ax.lines, linecolors=['DodgerBlue'])
28022802

2803+
@slow
2804+
def test_line_colors_and_styles_subplots(self):
2805+
from matplotlib import cm
2806+
2807+
custom_colors = 'rgcby'
2808+
df = DataFrame(randn(5, 5))
2809+
2810+
axes = df.plot(subplots=True)
2811+
for ax, c in zip(axes, list(custom_colors)):
2812+
self._check_colors(ax.get_lines(), linecolors=['k'])
2813+
tm.close()
2814+
2815+
axes = df.plot(color=custom_colors, subplots=True)
2816+
for ax, c in zip(axes, list(custom_colors)):
2817+
self._check_colors(ax.get_lines(), linecolors=[c])
2818+
tm.close()
2819+
2820+
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
2821+
for cmap in ['jet', cm.jet]:
2822+
axes = df.plot(colormap=cmap, subplots=True)
2823+
for ax, c in zip(axes, rgba_colors):
2824+
self._check_colors(ax.get_lines(), linecolors=[c])
2825+
tm.close()
2826+
2827+
# make color a list if plotting one column frame
2828+
# handles cases like df.plot(color='DodgerBlue')
2829+
axes = df.ix[:, [0]].plot(color='DodgerBlue', subplots=True)
2830+
self._check_colors(axes[0].lines, linecolors=['DodgerBlue'])
2831+
2832+
# single character style
2833+
axes = df.plot(style='r', subplots=True)
2834+
for ax in axes:
2835+
self._check_colors(ax.get_lines(), linecolors=['r'])
2836+
tm.close()
2837+
2838+
# list of styles
2839+
styles = list('rgcby')
2840+
axes = df.plot(style=styles, subplots=True)
2841+
for ax, c in zip(axes, styles):
2842+
self._check_colors(ax.get_lines(), linecolors=[c])
2843+
tm.close()
2844+
28032845
@slow
28042846
def test_area_colors(self):
28052847
from matplotlib import cm
@@ -2898,6 +2940,51 @@ def test_kde_colors(self):
28982940
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
28992941
self._check_colors(ax.get_lines(), linecolors=rgba_colors)
29002942

2943+
@slow
2944+
def test_kde_colors_and_styles_subplots(self):
2945+
tm._skip_if_no_scipy()
2946+
_skip_if_no_scipy_gaussian_kde()
2947+
2948+
from matplotlib import cm
2949+
2950+
custom_colors = 'rgcby'
2951+
df = DataFrame(randn(5, 5))
2952+
2953+
axes = df.plot(kind='kde', subplots=True)
2954+
for ax, c in zip(axes, list(custom_colors)):
2955+
self._check_colors(ax.get_lines(), linecolors=['k'])
2956+
tm.close()
2957+
2958+
axes = df.plot(kind='kde', color=custom_colors, subplots=True)
2959+
for ax, c in zip(axes, list(custom_colors)):
2960+
self._check_colors(ax.get_lines(), linecolors=[c])
2961+
tm.close()
2962+
2963+
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
2964+
for cmap in ['jet', cm.jet]:
2965+
axes = df.plot(kind='kde', colormap=cmap, subplots=True)
2966+
for ax, c in zip(axes, rgba_colors):
2967+
self._check_colors(ax.get_lines(), linecolors=[c])
2968+
tm.close()
2969+
2970+
# make color a list if plotting one column frame
2971+
# handles cases like df.plot(color='DodgerBlue')
2972+
axes = df.ix[:, [0]].plot(kind='kde', color='DodgerBlue', subplots=True)
2973+
self._check_colors(axes[0].lines, linecolors=['DodgerBlue'])
2974+
2975+
# single character style
2976+
axes = df.plot(kind='kde', style='r', subplots=True)
2977+
for ax in axes:
2978+
self._check_colors(ax.get_lines(), linecolors=['r'])
2979+
tm.close()
2980+
2981+
# list of styles
2982+
styles = list('rgcby')
2983+
axes = df.plot(kind='kde', style=styles, subplots=True)
2984+
for ax, c in zip(axes, styles):
2985+
self._check_colors(ax.get_lines(), linecolors=[c])
2986+
tm.close()
2987+
29012988
@slow
29022989
def test_boxplot_colors(self):
29032990

pandas/tools/plotting.py

+26-24
Original file line numberDiff line numberDiff line change
@@ -1264,36 +1264,39 @@ def on_right(self, i):
12641264
isinstance(self.secondary_y, (tuple, list, np.ndarray, Index))):
12651265
return self.data.columns[i] in self.secondary_y
12661266

1267-
def _get_style(self, i, col_name):
1268-
style = ''
1269-
if self.subplots:
1270-
style = 'k'
1267+
def _get_colors(self, num_colors=None, color_kwds='color'):
1268+
if num_colors is None:
1269+
num_colors = self.nseries
1270+
1271+
return _get_standard_colors(num_colors=num_colors,
1272+
colormap=self.colormap,
1273+
color=self.kwds.get(color_kwds))
12711274

1275+
def _apply_style_colors(self, colors, kwds, col_num, label):
1276+
"""
1277+
Manage style and color based on column number and its label.
1278+
Returns tuple of appropriate style and kwds which "color" may be added.
1279+
"""
1280+
style = None
12721281
if self.style is not None:
12731282
if isinstance(self.style, list):
12741283
try:
1275-
style = self.style[i]
1284+
style = self.style[col_num]
12761285
except IndexError:
12771286
pass
12781287
elif isinstance(self.style, dict):
1279-
style = self.style.get(col_name, style)
1288+
style = self.style.get(label, style)
12801289
else:
12811290
style = self.style
12821291

1283-
return style or None
1284-
1285-
def _get_colors(self, num_colors=None, color_kwds='color'):
1286-
if num_colors is None:
1287-
num_colors = self.nseries
1288-
1289-
return _get_standard_colors(num_colors=num_colors,
1290-
colormap=self.colormap,
1291-
color=self.kwds.get(color_kwds))
1292-
1293-
def _maybe_add_color(self, colors, kwds, style, i):
12941292
has_color = 'color' in kwds or self.colormap is not None
1295-
if has_color and (style is None or re.match('[a-z]+', style) is None):
1296-
kwds['color'] = colors[i % len(colors)]
1293+
if has_color:
1294+
if style is None or re.match('[a-z]+', style) is None:
1295+
kwds['color'] = colors[col_num % len(colors)]
1296+
else:
1297+
if self.subplots and style is None:
1298+
style = 'k'
1299+
return style, kwds
12971300

12981301
def _parse_errorbars(self, label, err):
12991302
'''
@@ -1618,9 +1621,8 @@ def _make_plot(self):
16181621
colors = self._get_colors()
16191622
for i, (label, y) in enumerate(it):
16201623
ax = self._get_ax(i)
1621-
style = self._get_style(i, label)
16221624
kwds = self.kwds.copy()
1623-
self._maybe_add_color(colors, kwds, style, i)
1625+
style, kwds = self._apply_style_colors(colors, kwds, i, label)
16241626

16251627
errors = self._get_errorbars(label=label, index=i)
16261628
kwds = dict(kwds, **errors)
@@ -1977,13 +1979,13 @@ def _make_plot(self):
19771979
colors = self._get_colors()
19781980
for i, (label, y) in enumerate(self._iter_data()):
19791981
ax = self._get_ax(i)
1980-
style = self._get_style(i, label)
1981-
label = com.pprint_thing(label)
19821982

19831983
kwds = self.kwds.copy()
1984+
1985+
label = com.pprint_thing(label)
19841986
kwds['label'] = label
1985-
self._maybe_add_color(colors, kwds, style, i)
19861987

1988+
style, kwds = self._apply_style_colors(colors, kwds, i, label)
19871989
if style is not None:
19881990
kwds['style'] = style
19891991

0 commit comments

Comments
 (0)