diff --git a/doc/source/release.rst b/doc/source/release.rst index b2eefda10fccc..0fa7b4b2ed5f2 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -52,6 +52,8 @@ pandas 0.12 - A ``filter`` method on grouped Series or DataFrames returns a subset of the original (:issue:`3680`, :issue:`919`) - Access to historical Google Finance data in pandas.io.data (:issue:`3814`) + - DataFrame plotting methods can sample column colors from a Matplotlib + colormap via the ``colormap`` keyword. (:issue:`3860`) **Improvements to existing features** diff --git a/doc/source/v0.12.0.txt b/doc/source/v0.12.0.txt index 643ef7ddbbab4..4b100ed0b5fab 100644 --- a/doc/source/v0.12.0.txt +++ b/doc/source/v0.12.0.txt @@ -96,6 +96,12 @@ API changes and thus you should cast to an appropriate numeric dtype if you need to plot something. + - Add ``colormap`` keyword to DataFrame plotting methods. Accepts either a + matplotlib colormap object (e.g., matplotlib.cm.jet) or a string name of such + an object (e.g., 'jet'). The colormap is sampled to select the color for each + column. Please see :ref:`visualization.colormaps` for more information. + (:issue:`3860`) + - ``DataFrame.interpolate()`` is now deprecated. Please use ``DataFrame.fillna()`` and ``DataFrame.replace()`` instead. (:issue:`3582`, :issue:`3675`, :issue:`3676`) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index f0790396a5c39..f1a9880047691 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -531,3 +531,65 @@ be colored differently. @savefig radviz.png width=6in radviz(data, 'Name') + +.. _visualization.colormaps: + +Colormaps +~~~~~~~~~ + +A potential issue when plotting a large number of columns is that it can be difficult to distinguish some series due to repetition in the default colors. 
To remedy this, DataFrame plotting supports the use of the ``colormap=`` argument, which accepts either a Matplotlib `colormap <http://matplotlib.org/api/cm_api.html>`__ or a string that is a name of a colormap registered with Matplotlib. A visualization of the default matplotlib colormaps is available `here <http://wiki.scipy.org/Cookbook/Matplotlib/Show_colormaps>`__. + +As matplotlib does not directly support colormaps for line-based plots, the colors are selected based on an even spacing determined by the number of columns in the DataFrame. There is no consideration made for background color, so some colormaps will produce lines that are not easily visible. + +To use the jet colormap, we can simply pass ``'jet'`` to ``colormap=`` + +.. ipython:: python + + df = DataFrame(randn(1000, 10), index=ts.index) + df = df.cumsum() + + plt.figure() + + @savefig jet.png width=6in + df.plot(colormap='jet') + +or we can pass the colormap itself + +.. ipython:: python + + from matplotlib import cm + + plt.figure() + + @savefig jet_cm.png width=6in + df.plot(colormap=cm.jet) + +Colormaps can also be used with other plot types, like bar charts: + +.. ipython:: python + + dd = DataFrame(randn(10, 10)).applymap(abs) + dd = dd.cumsum() + + plt.figure() + + @savefig greens.png width=6in + dd.plot(kind='bar', colormap='Greens') + +Parallel coordinates charts: + +.. ipython:: python + + plt.figure() + + @savefig parallel_gist_rainbow.png width=6in + parallel_coordinates(data, 'Name', colormap='gist_rainbow') + +Andrews curves charts: + +.. 
ipython:: python + + plt.figure() + + @savefig andrews_curve_winter.png width=6in + andrews_curves(data, 'Name', colormap='winter') diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index e57e5a9af2fc0..d094e8b99d9cb 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -103,6 +103,35 @@ def test_bar_colors(self): self.assert_(xp == rs) plt.close('all') + + from matplotlib import cm + + # Test str -> colormap functionality + ax = df.plot(kind='bar', colormap='jet') + + rects = ax.patches + + rgba_colors = map(cm.jet, np.linspace(0, 1, 5)) + for i, rect in enumerate(rects[::5]): + xp = rgba_colors[i] + rs = rect.get_facecolor() + self.assert_(xp == rs) + + plt.close('all') + + # Test colormap functionality + ax = df.plot(kind='bar', colormap=cm.jet) + + rects = ax.patches + + rgba_colors = map(cm.jet, np.linspace(0, 1, 5)) + for i, rect in enumerate(rects[::5]): + xp = rgba_colors[i] + rs = rect.get_facecolor() + self.assert_(xp == rs) + + plt.close('all') + df.ix[:, [0]].plot(kind='bar', color='DodgerBlue') @slow @@ -600,6 +629,7 @@ def test_andrews_curves(self): def test_parallel_coordinates(self): from pandas import read_csv from pandas.tools.plotting import parallel_coordinates + from matplotlib import cm path = os.path.join(curpath(), 'data/iris.csv') df = read_csv(path) _check_plot_works(parallel_coordinates, df, 'Name') @@ -611,6 +641,7 @@ def test_parallel_coordinates(self): colors=('#556270', '#4ECDC4', '#C7F464')) _check_plot_works(parallel_coordinates, df, 'Name', colors=['dodgerblue', 'aquamarine', 'seagreen']) + _check_plot_works(parallel_coordinates, df, 'Name', colormap=cm.jet) df = read_csv( path, header=None, skiprows=1, names=[1, 2, 4, 8, 'Name']) @@ -622,9 +653,11 @@ def test_parallel_coordinates(self): def test_radviz(self): from pandas import read_csv from pandas.tools.plotting import radviz + from matplotlib import cm path = os.path.join(curpath(), 'data/iris.csv') df = read_csv(path) 
_check_plot_works(radviz, df, 'Name') + _check_plot_works(radviz, df, 'Name', colormap=cm.jet) @slow def test_plot_int_columns(self): @@ -666,6 +699,7 @@ def test_line_colors(self): import matplotlib.pyplot as plt import sys from StringIO import StringIO + from matplotlib import cm custom_colors = 'rgcby' @@ -691,6 +725,30 @@ def test_line_colors(self): finally: sys.stderr = tmp + plt.close('all') + + ax = df.plot(colormap='jet') + + rgba_colors = map(cm.jet, np.linspace(0, 1, len(df))) + + lines = ax.get_lines() + for i, l in enumerate(lines): + xp = rgba_colors[i] + rs = l.get_color() + self.assert_(xp == rs) + + plt.close('all') + + ax = df.plot(colormap=cm.jet) + + rgba_colors = map(cm.jet, np.linspace(0, 1, len(df))) + + lines = ax.get_lines() + for i, l in enumerate(lines): + xp = rgba_colors[i] + rs = l.get_color() + self.assert_(xp == rs) + # make color a list if plotting one column frame # handles cases like df.plot(color='DodgerBlue') plt.close('all') @@ -862,6 +920,10 @@ def test_option_mpl_style(self): except ValueError: pass + def test_invalid_colormap(self): + df = DataFrame(np.random.randn(500, 2), columns=['A', 'B']) + + self.assertRaises(ValueError, df.plot, colormap='invalid_colormap') def _check_plot_works(f, *args, **kwargs): import matplotlib.pyplot as plt diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index a5aaac05d8ad8..8abe9df5ddd56 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -91,6 +91,43 @@ def _get_standard_kind(kind): return {'density': 'kde'}.get(kind, kind) +def _get_standard_colors(num_colors=None, colormap=None, + color_type='default', color=None): + import matplotlib.pyplot as plt + + if color is None and colormap is not None: + if isinstance(colormap, basestring): + import matplotlib.cm as cm + colormap = cm.get_cmap(colormap) + colors = map(colormap, np.linspace(0, 1, num=num_colors)) + elif color is not None: + if colormap is not None: + warnings.warn("'color' and 'colormap' cannot be 
used " + "simultaneously. Using 'color'") + colors = color + else: + if color_type == 'default': + colors = plt.rcParams.get('axes.color_cycle', list('bgrcmyk')) + if isinstance(colors, basestring): + colors = list(colors) + elif color_type == 'random': + import random + def random_color(column): + random.seed(column) + return [random.random() for _ in range(3)] + + colors = map(random_color, range(num_colors)) + else: + raise NotImplementedError + + if len(colors) != num_colors: + multiple = num_colors//len(colors) - 1 + mod = num_colors % len(colors) + + colors += multiple * colors + colors += colors[:mod] + + return colors class _Options(dict): """ @@ -283,7 +320,7 @@ def _get_marker_compat(marker): return marker -def radviz(frame, class_column, ax=None, **kwds): +def radviz(frame, class_column, ax=None, colormap=None, **kwds): """RadViz - a multivariate data visualization algorithm Parameters: @@ -291,6 +328,9 @@ def radviz(frame, class_column, ax=None, **kwds): frame: DataFrame object class_column: Column name that contains information about class membership ax: Matplotlib axis object, optional + colormap : str or matplotlib colormap object, default None + Colormap to select colors from. If string, load colormap with that name + from matplotlib. 
kwds: Matplotlib scatter method keyword arguments, optional Returns: @@ -302,10 +342,6 @@ def radviz(frame, class_column, ax=None, **kwds): import matplotlib.text as text import random - def random_color(column): - random.seed(column) - return [random.random() for _ in range(3)] - def normalize(series): a = min(series) b = max(series) @@ -322,6 +358,9 @@ def normalize(series): classes = set(frame[class_column]) to_plot = {} + colors = _get_standard_colors(num_colors=len(classes), colormap=colormap, + color_type='random', color=kwds.get('color')) + for class_ in classes: to_plot[class_] = [[], []] @@ -338,10 +377,10 @@ def normalize(series): to_plot[class_name][0].append(y[0]) to_plot[class_name][1].append(y[1]) - for class_ in classes: + for i, class_ in enumerate(classes): line = ax.scatter(to_plot[class_][0], to_plot[class_][1], - color=random_color(class_), + color=colors[i], label=com.pprint_thing(class_), **kwds) ax.legend() @@ -368,7 +407,8 @@ def normalize(series): return ax -def andrews_curves(data, class_column, ax=None, samples=200): +def andrews_curves(data, class_column, ax=None, samples=200, colormap=None, + **kwds): """ Parameters: ----------- @@ -377,6 +417,10 @@ def andrews_curves(data, class_column, ax=None, samples=200): class_column : Name of the column containing class names ax : matplotlib axes object, default None samples : Number of points to plot in each curve + colormap : str or matplotlib colormap object, default None + Colormap to select colors from. If string, load colormap with that name + from matplotlib. 
+ kwds : Optional plotting arguments to be passed to matplotlib Returns: -------- @@ -401,15 +445,17 @@ def f(x): return result return f - def random_color(column): - random.seed(column) - return [random.random() for _ in range(3)] + n = len(data) classes = set(data[class_column]) class_col = data[class_column] columns = [data[col] for col in data.columns if (col != class_column)] x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)] used_legends = set([]) + + colors = _get_standard_colors(num_colors=n, colormap=colormap, + color_type='random', color=kwds.get('color')) + if ax is None: ax = plt.gca(xlim=(-pi, pi)) for i in range(n): @@ -420,9 +466,9 @@ def random_color(column): if com.pprint_thing(class_col[i]) not in used_legends: label = com.pprint_thing(class_col[i]) used_legends.add(label) - ax.plot(x, y, color=random_color(class_col[i]), label=label) + ax.plot(x, y, color=colors[i], label=label, **kwds) else: - ax.plot(x, y, color=random_color(class_col[i])) + ax.plot(x, y, color=colors[i], **kwds) ax.legend(loc='upper right') ax.grid() @@ -492,7 +538,7 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds): def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None, - use_columns=False, xticks=None, **kwds): + use_columns=False, xticks=None, colormap=None, **kwds): """Parallel coordinates plotting. Parameters @@ -511,6 +557,8 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None, If true, columns will be used as xticks xticks: list or tuple, optional A list of values to use for xticks + colormap: str or matplotlib colormap, default None + Colormap to use for line colors. 
kwds: list, optional A list of keywords for matplotlib plot method @@ -530,9 +578,7 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None, import matplotlib.pyplot as plt import random - def random_color(column): - random.seed(column) - return [random.random() for _ in range(3)] + n = len(data) classes = set(data[class_column]) class_col = data[class_column] @@ -563,13 +609,11 @@ def random_color(column): if ax is None: ax = plt.gca() - # if user has not specified colors to use, choose at random - if colors is None: - colors = dict((kls, random_color(kls)) for kls in classes) - else: - if len(colors) != len(classes): - raise ValueError('Number of colors must match number of classes') - colors = dict((kls, colors[i]) for i, kls in enumerate(classes)) + color_values = _get_standard_colors(num_colors=len(classes), + colormap=colormap, color_type='random', + color=colors) + + colors = dict(zip(classes, color_values)) for i in range(n): row = df.irow(i).values @@ -714,7 +758,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, ax=None, fig=None, title=None, xlim=None, ylim=None, xticks=None, yticks=None, sort_columns=False, fontsize=None, - secondary_y=False, **kwds): + secondary_y=False, colormap=None, **kwds): self.data = data self.by = by @@ -756,6 +800,8 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, secondary_y = [secondary_y] self.secondary_y = secondary_y + self.colormap = colormap + self.kwds = kwds self._validate_color_args() @@ -774,6 +820,11 @@ def _validate_color_args(self): # support series.plot(color='green') self.kwds['color'] = [self.kwds['color']] + if ('color' in self.kwds or 'colors' in self.kwds) and \ + self.colormap is not None: + warnings.warn("'color' and 'colormap' cannot be used " + "simultaneously. 
Using 'color'") + def _iter_data(self): from pandas.core.frame import DataFrame if isinstance(self.data, (Series, np.ndarray)): @@ -1072,15 +1123,18 @@ def _get_style(self, i, col_name): return style or None def _get_colors(self): - import matplotlib.pyplot as plt - cycle = plt.rcParams.get('axes.color_cycle', list('bgrcmyk')) - if isinstance(cycle, basestring): - cycle = list(cycle) - colors = self.kwds.get('color', cycle) - return colors + from pandas.core.frame import DataFrame + if isinstance(self.data, DataFrame): + num_colors = len(self.data.columns) + else: + num_colors = 1 + + return _get_standard_colors(num_colors=num_colors, + colormap=self.colormap, + color=self.kwds.get('color')) def _maybe_add_color(self, colors, kwds, style, i): - has_color = 'color' in kwds + has_color = 'color' in kwds or self.colormap is not None if has_color and (style is None or re.match('[a-z]+', style) is None): kwds['color'] = colors[i % len(colors)] @@ -1090,6 +1144,7 @@ def _get_marked_label(self, label, col_num): else: return label + class KdePlot(MPLPlot): def __init__(self, data, **kwargs): MPLPlot.__init__(self, data, **kwargs) @@ -1389,15 +1444,6 @@ def f(ax, x, y, w, start=None, log=self.log, **kwds): return f - def _get_colors(self): - import matplotlib.pyplot as plt - cycle = plt.rcParams.get('axes.color_cycle', list('bgrcmyk')) - if isinstance(cycle, basestring): - cycle = list(cycle) - has_colors = 'color' in self.kwds - colors = self.kwds.get('color', cycle) - return colors - def _make_plot(self): import matplotlib as mpl colors = self._get_colors() @@ -1547,6 +1593,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, mark_right: boolean, default True When using a secondary_y axis, should the legend label the axis of the various columns automatically + colormap : str or matplotlib colormap object, default None + Colormap to select colors from. If string, load colormap with that name + from matplotlib. 
kwds : keywords Options to pass to matplotlib plotting method @@ -1724,12 +1773,7 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None, def _get_colors(): - import matplotlib.pyplot as plt - cycle = plt.rcParams.get('axes.color_cycle', list('bgrcmyk')) - if isinstance(cycle, basestring): - cycle = list(cycle) - colors = kwds.get('color', cycle) - return colors + return _get_standard_colors(color=kwds.get('color'), num_colors=1) def maybe_color_bp(bp): if 'color' not in kwds :