Skip to content

Commit 3995b2b

Browse files
committed
ENH/BUG: color cannot be applied to line subplots
1 parent 8f0f417 commit 3995b2b

File tree

3 files changed

+133
-24
lines changed

3 files changed

+133
-24
lines changed

doc/source/whatsnew/v0.16.1.txt

+9
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ API changes
207207

208208
- By default, ``read_csv`` and ``read_table`` will now try to infer the compression type based on the file extension. Set ``compression=None`` to restore the previous behavior (no decompression). (:issue:`9770`)
209209

210+
211+
- Line and kde plot with ``subplots=True`` now uses default colors, not all black. Specify ``color='k'`` to draw all lines in black (:issue:`9894`)
212+
213+
210214
.. _whatsnew_0161.performance:
211215

212216
Performance Improvements
@@ -256,6 +260,11 @@ Bug Fixes
256260
- Bug in which ``SparseDataFrame`` could not take `nan` as a column name (:issue:`8822`)
257261
- Bug in ``to_msgpack`` and ``read_msgpack`` zlib and blosc compression support (:issue:`9783`)
258262

263+
264+
265+
- Bug in line and kde plot cannot accept multiple colors when ``subplots=True`` (:issue:`9894`)
266+
267+
259268
- Bug ``GroupBy.size`` doesn't attach index name properly if grouped by ``TimeGrouper`` (:issue:`9925`)
260269
- Bug causing an exception in slice assignments because ``length_of_indexer`` returns wrong results (:issue:`9995`)
261270
- Bug in csv parser causing lines with initial whitespace plus one non-space character to be skipped. (:issue:`9710`)

pandas/tests/test_graphics.py

+101
Original file line numberDiff line numberDiff line change
@@ -2836,6 +2836,55 @@ def test_line_colors(self):
28362836
ax = df.ix[:, [0]].plot(color='DodgerBlue')
28372837
self._check_colors(ax.lines, linecolors=['DodgerBlue'])
28382838

2839+
@slow
2840+
def test_line_colors_and_styles_subplots(self):
2841+
from matplotlib import cm
2842+
default_colors = self.plt.rcParams.get('axes.color_cycle')
2843+
2844+
df = DataFrame(randn(5, 5))
2845+
2846+
axes = df.plot(subplots=True)
2847+
for ax, c in zip(axes, list(default_colors)):
2848+
self._check_colors(ax.get_lines(), linecolors=c)
2849+
tm.close()
2850+
2851+
# single color
2852+
axes = df.plot(subplots=True, color='k')
2853+
for ax in axes:
2854+
self._check_colors(ax.get_lines(), linecolors=['k'])
2855+
tm.close()
2856+
2857+
custom_colors = 'rgcby'
2858+
axes = df.plot(color=custom_colors, subplots=True)
2859+
for ax, c in zip(axes, list(custom_colors)):
2860+
self._check_colors(ax.get_lines(), linecolors=[c])
2861+
tm.close()
2862+
2863+
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
2864+
for cmap in ['jet', cm.jet]:
2865+
axes = df.plot(colormap=cmap, subplots=True)
2866+
for ax, c in zip(axes, rgba_colors):
2867+
self._check_colors(ax.get_lines(), linecolors=[c])
2868+
tm.close()
2869+
2870+
# make color a list if plotting one column frame
2871+
# handles cases like df.plot(color='DodgerBlue')
2872+
axes = df.ix[:, [0]].plot(color='DodgerBlue', subplots=True)
2873+
self._check_colors(axes[0].lines, linecolors=['DodgerBlue'])
2874+
2875+
# single character style
2876+
axes = df.plot(style='r', subplots=True)
2877+
for ax in axes:
2878+
self._check_colors(ax.get_lines(), linecolors=['r'])
2879+
tm.close()
2880+
2881+
# list of styles
2882+
styles = list('rgcby')
2883+
axes = df.plot(style=styles, subplots=True)
2884+
for ax, c in zip(axes, styles):
2885+
self._check_colors(ax.get_lines(), linecolors=[c])
2886+
tm.close()
2887+
28392888
@slow
28402889
def test_area_colors(self):
28412890
from matplotlib import cm
@@ -2934,6 +2983,58 @@ def test_kde_colors(self):
29342983
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
29352984
self._check_colors(ax.get_lines(), linecolors=rgba_colors)
29362985

2986+
@slow
2987+
def test_kde_colors_and_styles_subplots(self):
2988+
tm._skip_if_no_scipy()
2989+
_skip_if_no_scipy_gaussian_kde()
2990+
2991+
from matplotlib import cm
2992+
default_colors = self.plt.rcParams.get('axes.color_cycle')
2993+
2994+
df = DataFrame(randn(5, 5))
2995+
2996+
axes = df.plot(kind='kde', subplots=True)
2997+
for ax, c in zip(axes, list(default_colors)):
2998+
self._check_colors(ax.get_lines(), linecolors=[c])
2999+
tm.close()
3000+
3001+
# single color
3002+
axes = df.plot(kind='kde', color='k', subplots=True)
3003+
for ax in axes:
3004+
self._check_colors(ax.get_lines(), linecolors=['k'])
3005+
tm.close()
3006+
3007+
custom_colors = 'rgcby'
3008+
axes = df.plot(kind='kde', color=custom_colors, subplots=True)
3009+
for ax, c in zip(axes, list(custom_colors)):
3010+
self._check_colors(ax.get_lines(), linecolors=[c])
3011+
tm.close()
3012+
3013+
rgba_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
3014+
for cmap in ['jet', cm.jet]:
3015+
axes = df.plot(kind='kde', colormap=cmap, subplots=True)
3016+
for ax, c in zip(axes, rgba_colors):
3017+
self._check_colors(ax.get_lines(), linecolors=[c])
3018+
tm.close()
3019+
3020+
# make color a list if plotting one column frame
3021+
# handles cases like df.plot(color='DodgerBlue')
3022+
axes = df.ix[:, [0]].plot(kind='kde', color='DodgerBlue', subplots=True)
3023+
self._check_colors(axes[0].lines, linecolors=['DodgerBlue'])
3024+
3025+
# single character style
3026+
axes = df.plot(kind='kde', style='r', subplots=True)
3027+
for ax in axes:
3028+
self._check_colors(ax.get_lines(), linecolors=['r'])
3029+
tm.close()
3030+
3031+
# list of styles
3032+
styles = list('rgcby')
3033+
axes = df.plot(kind='kde', style=styles, subplots=True)
3034+
for ax, c in zip(axes, styles):
3035+
self._check_colors(ax.get_lines(), linecolors=[c])
3036+
tm.close()
3037+
29373038
@slow
29383039
def test_boxplot_colors(self):
29393040

pandas/tools/plotting.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -1262,36 +1262,36 @@ def on_right(self, i):
12621262
if isinstance(self.secondary_y, (tuple, list, np.ndarray, Index)):
12631263
return self.data.columns[i] in self.secondary_y
12641264

1265-
def _get_style(self, i, col_name):
1266-
style = ''
1267-
if self.subplots:
1268-
style = 'k'
1265+
def _get_colors(self, num_colors=None, color_kwds='color'):
1266+
if num_colors is None:
1267+
num_colors = self.nseries
12691268

1269+
return _get_standard_colors(num_colors=num_colors,
1270+
colormap=self.colormap,
1271+
color=self.kwds.get(color_kwds))
1272+
1273+
def _apply_style_colors(self, colors, kwds, col_num, label):
1274+
"""
1275+
Manage style and color based on column number and its label.
1276+
Returns tuple of appropriate style and kwds which "color" may be added.
1277+
"""
1278+
style = None
12701279
if self.style is not None:
12711280
if isinstance(self.style, list):
12721281
try:
1273-
style = self.style[i]
1282+
style = self.style[col_num]
12741283
except IndexError:
12751284
pass
12761285
elif isinstance(self.style, dict):
1277-
style = self.style.get(col_name, style)
1286+
style = self.style.get(label, style)
12781287
else:
12791288
style = self.style
12801289

1281-
return style or None
1282-
1283-
def _get_colors(self, num_colors=None, color_kwds='color'):
1284-
if num_colors is None:
1285-
num_colors = self.nseries
1286-
1287-
return _get_standard_colors(num_colors=num_colors,
1288-
colormap=self.colormap,
1289-
color=self.kwds.get(color_kwds))
1290-
1291-
def _maybe_add_color(self, colors, kwds, style, i):
12921290
has_color = 'color' in kwds or self.colormap is not None
1293-
if has_color and (style is None or re.match('[a-z]+', style) is None):
1294-
kwds['color'] = colors[i % len(colors)]
1291+
nocolor_style = style is None or re.match('[a-z]+', style) is None
1292+
if (has_color or self.subplots) and nocolor_style:
1293+
kwds['color'] = colors[col_num % len(colors)]
1294+
return style, kwds
12951295

12961296
def _parse_errorbars(self, label, err):
12971297
'''
@@ -1612,9 +1612,8 @@ def _make_plot(self):
16121612
colors = self._get_colors()
16131613
for i, (label, y) in enumerate(it):
16141614
ax = self._get_ax(i)
1615-
style = self._get_style(i, label)
16161615
kwds = self.kwds.copy()
1617-
self._maybe_add_color(colors, kwds, style, i)
1616+
style, kwds = self._apply_style_colors(colors, kwds, i, label)
16181617

16191618
errors = self._get_errorbars(label=label, index=i)
16201619
kwds = dict(kwds, **errors)
@@ -1963,13 +1962,13 @@ def _make_plot(self):
19631962
colors = self._get_colors()
19641963
for i, (label, y) in enumerate(self._iter_data()):
19651964
ax = self._get_ax(i)
1966-
style = self._get_style(i, label)
1967-
label = com.pprint_thing(label)
19681965

19691966
kwds = self.kwds.copy()
1967+
1968+
label = com.pprint_thing(label)
19701969
kwds['label'] = label
1971-
self._maybe_add_color(colors, kwds, style, i)
19721970

1971+
style, kwds = self._apply_style_colors(colors, kwds, i, label)
19731972
if style is not None:
19741973
kwds['style'] = style
19751974

0 commit comments

Comments
 (0)