From 880a3d083f6ad8489385341d5032b902f76d11e3 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Mon, 31 Mar 2014 23:26:46 +0900 Subject: [PATCH] ENH: Scatter plot now supports errorbar --- doc/source/release.rst | 2 +- doc/source/v0.14.0.txt | 2 +- doc/source/visualization.rst | 5 +- pandas/tests/test_graphics.py | 155 ++++++++++++++++++++++++++-------- pandas/tools/plotting.py | 132 ++++++++++++++++------------- 5 files changed, 198 insertions(+), 98 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index a0a96ebbd5c70..8853ce79f3d04 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -63,7 +63,7 @@ New features Date is used primarily in astronomy and represents the number of days from noon, January 1, 4713 BC. Because nanoseconds are used to define the time in pandas the actual range of dates that you can use is 1678 AD to 2262 AD. (:issue:`4041`) -- Added error bar support to the ``.plot`` method of ``DataFrame`` and ``Series`` (:issue:`3796`) +- Added error bar support to the ``.plot`` method of ``DataFrame`` and ``Series`` (:issue:`3796`, :issue:`6834`) - Implemented ``Panel.pct_change`` (:issue:`6904`) API Changes diff --git a/doc/source/v0.14.0.txt b/doc/source/v0.14.0.txt index a6caa075f6358..c70e32fd18694 100644 --- a/doc/source/v0.14.0.txt +++ b/doc/source/v0.14.0.txt @@ -365,7 +365,7 @@ Plotting - Hexagonal bin plots from ``DataFrame.plot`` with ``kind='hexbin'`` (:issue:`5478`), See :ref:`the docs`. - ``DataFrame.plot`` and ``Series.plot`` now supports area plot with specifying ``kind='area'`` (:issue:`6656`) -- Plotting with Error Bars is now supported in the ``.plot`` method of ``DataFrame`` and ``Series`` objects (:issue:`3796`), See :ref:`the docs`. +- Plotting with Error Bars is now supported in the ``.plot`` method of ``DataFrame`` and ``Series`` objects (:issue:`3796`, :issue:`6834`), See :ref:`the docs`. - ``DataFrame.plot`` and ``Series.plot`` now support a ``table`` keyword for plotting ``matplotlib.Table``, See :ref:`the docs`. - ``plot(legend='reverse')`` will now reverse the order of legend labels for most plot kinds. (:issue:`6014`) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index 5255ddf3c33e7..8906e82eb937b 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -394,10 +394,13 @@ x and y errorbars are supported and be supplied using the ``xerr`` and ``yerr`` - As a ``DataFrame`` or ``dict`` of errors with column names matching the ``columns`` attribute of the plotting ``DataFrame`` or matching the ``name`` attribute of the ``Series`` - As a ``str`` indicating which of the columns of plotting ``DataFrame`` contain the error values -- As raw values (``list``, ``tuple``, or ``np.ndarray``). Must be the same length as the plotting ``DataFrame``/``Series`` +- As list-like raw values (``list``, ``tuple``, or ``np.ndarray``). Must be the same length as the plotting ``DataFrame``/``Series`` +- As float. The error value will be applied to all data. Asymmetrical error bars are also supported, however raw error values must be provided in this case. For a ``M`` length ``Series``, a ``Mx2`` array should be provided indicating lower and upper (or left and right) errors. For a ``MxN`` ``DataFrame``, asymmetrical errors should be in a ``Mx2xN`` array. +**Note**: Plotting ``xerr`` is not supported in time series. + Here is an example of one way to easily plot group means with standard deviations from the raw data. .. ipython:: python diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 8b79c9e9d1307..0186ac4c2b74b 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# coding: utf-8 + import nose import os import string @@ -27,7 +30,6 @@ def _skip_if_no_scipy(): except ImportError: raise nose.SkipTest("no scipy") - @tm.mplskip class TestSeriesPlots(tm.TestCase): def setUp(self): @@ -315,24 +317,36 @@ def test_dup_datetime_index_plot(self): @slow def test_errorbar_plot(self): - s = Series(np.arange(10)) + s = Series(np.arange(10), name='x') s_err = np.random.randn(10) - + d_err = DataFrame(randn(10, 2), index=s.index, columns=['x', 'y']) # test line and bar plots kinds = ['line', 'bar'] for kind in kinds: - _check_plot_works(s.plot, yerr=Series(s_err), kind=kind) - _check_plot_works(s.plot, yerr=s_err, kind=kind) - _check_plot_works(s.plot, yerr=s_err.tolist(), kind=kind) - - _check_plot_works(s.plot, xerr=s_err) + ax = _check_plot_works(s.plot, yerr=Series(s_err), kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(s.plot, yerr=s_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(s.plot, yerr=s_err.tolist(), kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(s.plot, yerr=d_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(s.plot, xerr=0.2, yerr=0.2, kind=kind) + _check_has_errorbars(self, ax, xerr=1, yerr=1) + + ax = _check_plot_works(s.plot, xerr=s_err) + _check_has_errorbars(self, ax, xerr=1, yerr=0) # test time series plotting ix = date_range('1/1/2000', '1/1/2001', freq='M') - ts = Series(np.arange(12), index=ix) + ts = Series(np.arange(12), index=ix, name='x') ts_err = Series(np.random.randn(12), index=ix) + td_err = DataFrame(randn(12, 2), index=ix, columns=['x', 'y']) - _check_plot_works(ts.plot, yerr=ts_err) + ax = _check_plot_works(ts.plot, yerr=ts_err) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(ts.plot, yerr=td_err) + _check_has_errorbars(self, ax, xerr=0, yerr=1) # check incorrect lengths and types with tm.assertRaises(ValueError): @@ -1505,27 +1519,51 @@ def test_errorbar_plot(self): df_err = DataFrame(d_err) # check line plots - _check_plot_works(df.plot, yerr=df_err, logy=True) - _check_plot_works(df.plot, yerr=df_err, logx=True, logy=True) - _check_plot_works(df.plot, yerr=df_err, loglog=True) + ax = _check_plot_works(df.plot, yerr=df_err, logy=True) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(df.plot, yerr=df_err, logx=True, logy=True) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(df.plot, yerr=df_err, loglog=True) + _check_has_errorbars(self, ax, xerr=0, yerr=2) kinds = ['line', 'bar', 'barh'] for kind in kinds: - _check_plot_works(df.plot, yerr=df_err['x'], kind=kind) - _check_plot_works(df.plot, yerr=d_err, kind=kind) - _check_plot_works(df.plot, yerr=df_err, xerr=df_err, kind=kind) - _check_plot_works(df.plot, yerr=df_err['x'], xerr=df_err['x'], kind=kind) - _check_plot_works(df.plot, yerr=df_err, xerr=df_err, subplots=True, kind=kind) + ax = _check_plot_works(df.plot, yerr=df_err['x'], kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(df.plot, yerr=d_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(df.plot, yerr=df_err, xerr=df_err, kind=kind) + _check_has_errorbars(self, ax, xerr=2, yerr=2) + ax = _check_plot_works(df.plot, yerr=df_err['x'], xerr=df_err['x'], kind=kind) + _check_has_errorbars(self, ax, xerr=2, yerr=2) + ax = _check_plot_works(df.plot, xerr=0.2, yerr=0.2, kind=kind) + _check_has_errorbars(self, ax, xerr=2, yerr=2) + axes = _check_plot_works(df.plot, yerr=df_err, xerr=df_err, subplots=True, kind=kind) + for ax in axes: + _check_has_errorbars(self, ax, xerr=1, yerr=1) - _check_plot_works((df+1).plot, yerr=df_err, xerr=df_err, kind='bar', log=True) + ax = _check_plot_works((df+1).plot, yerr=df_err, xerr=df_err, kind='bar', log=True) + _check_has_errorbars(self, ax, xerr=2, yerr=2) # yerr is raw error values - _check_plot_works(df['y'].plot, yerr=np.ones(12)*0.4) - _check_plot_works(df.plot, yerr=np.ones((2, 12))*0.4) + ax = _check_plot_works(df['y'].plot, yerr=np.ones(12)*0.4) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(df.plot, yerr=np.ones((2, 12))*0.4) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + + # yerr is iterator + import itertools + ax = _check_plot_works(df.plot, yerr=itertools.repeat(0.1, len(df))) + _check_has_errorbars(self, ax, xerr=0, yerr=2) # yerr is column name - df['yerr'] = np.ones(12)*0.2 - _check_plot_works(df.plot, y='y', x='x', yerr='yerr') + for yerr in ['yerr', u('誤差')]: + s_df = df.copy() + s_df[yerr] = np.ones(12)*0.2 + ax = _check_plot_works(s_df.plot, yerr=yerr) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(s_df.plot, y='y', x='x', yerr=yerr) + _check_has_errorbars(self, ax, xerr=0, yerr=1) with tm.assertRaises(ValueError): df.plot(yerr=np.random.randn(11)) @@ -1539,8 +1577,10 @@ def test_errorbar_with_integer_column_names(self): # test with integer column names df = DataFrame(np.random.randn(10, 2)) df_err = DataFrame(np.random.randn(10, 2)) - _check_plot_works(df.plot, yerr=df_err) - _check_plot_works(df.plot, y=0, yerr=1) + ax = _check_plot_works(df.plot, yerr=df_err) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(df.plot, y=0, yerr=1) + _check_has_errorbars(self, ax, xerr=0, yerr=1) @slow def test_errorbar_with_partial_columns(self): @@ -1548,12 +1588,22 @@ def test_errorbar_with_partial_columns(self): df_err = DataFrame(np.random.randn(10, 2), columns=[0, 2]) kinds = ['line', 'bar'] for kind in kinds: - _check_plot_works(df.plot, yerr=df_err, kind=kind) + ax = _check_plot_works(df.plot, yerr=df_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) ix = date_range('1/1/2000', periods=10, freq='M') df.set_index(ix, inplace=True) df_err.set_index(ix, inplace=True) - _check_plot_works(df.plot, yerr=df_err, kind='line') + ax = _check_plot_works(df.plot, yerr=df_err, kind='line') + _check_has_errorbars(self, ax, xerr=0, yerr=2) + + d = {'x': np.arange(12), 'y': np.arange(12, 0, -1)} + df = DataFrame(d) + d_err = {'x': np.ones(12)*0.2, 'z': np.ones(12)*0.4} + df_err = DataFrame(d_err) + for err in [d_err, df_err]: + ax = _check_plot_works(df.plot, yerr=err) + _check_has_errorbars(self, ax, xerr=0, yerr=1) @slow def test_errorbar_timeseries(self): @@ -1568,13 +1618,19 @@ def test_errorbar_timeseries(self): kinds = ['line', 'bar', 'barh'] for kind in kinds: - _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) - _check_plot_works(tdf.plot, yerr=d_err, kind=kind) - _check_plot_works(tdf.plot, y='y', kind=kind) - _check_plot_works(tdf.plot, y='y', yerr='x', kind=kind) - _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) - _check_plot_works(tdf.plot, kind=kind, subplots=True) - + ax = _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(tdf.plot, yerr=d_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + ax = _check_plot_works(tdf.plot, y='y', yerr=tdf_err['x'], kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(tdf.plot, y='y', yerr='x', kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) + _check_has_errorbars(self, ax, xerr=0, yerr=2) + axes = _check_plot_works(tdf.plot, kind=kind, yerr=tdf_err, subplots=True) + for ax in axes: + _check_has_errorbars(self, ax, xerr=0, yerr=1) def test_errorbar_asymmetrical(self): @@ -1608,6 +1664,21 @@ def test_table(self): plotting.table(ax, df.T) self.assert_(len(ax.tables) == 1) + def test_errorbar_scatter(self): + df = DataFrame(np.random.randn(5, 2), index=range(5), columns=['x', 'y']) + df_err = DataFrame(np.random.randn(5, 2) / 5, + index=range(5), columns=['x', 'y']) + + ax = _check_plot_works(df.plot, kind='scatter', x='x', y='y') + _check_has_errorbars(self, ax, xerr=0, yerr=0) + ax = _check_plot_works(df.plot, kind='scatter', x='x', y='y', xerr=df_err) + _check_has_errorbars(self, ax, xerr=1, yerr=0) + ax = _check_plot_works(df.plot, kind='scatter', x='x', y='y', yerr=df_err) + _check_has_errorbars(self, ax, xerr=0, yerr=1) + ax = _check_plot_works(df.plot, kind='scatter', x='x', y='y', + xerr=df_err, yerr=df_err) + _check_has_errorbars(self, ax, xerr=1, yerr=1) + @tm.mplskip class TestDataFrameGroupByPlots(tm.TestCase): @@ -1803,8 +1874,24 @@ def assert_is_valid_plot_return_object(objs): ''.format(objs.__class__.__name__)) +def _check_has_errorbars(t, ax, xerr=0, yerr=0): + containers = ax.containers + xerr_count = 0 + yerr_count = 0 + for c in containers: + has_xerr = getattr(c, 'has_xerr', False) + has_yerr = getattr(c, 'has_yerr', False) + if has_xerr: + xerr_count += 1 + if has_yerr: + yerr_count += 1 + t.assertEqual(xerr, xerr_count) + t.assertEqual(yerr, yerr_count) + + def _check_plot_works(f, *args, **kwargs): import matplotlib.pyplot as plt + ret = None try: try: diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index ab3717d52e4f2..55aa01fd2e265 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -16,7 +16,7 @@ from pandas.tseries.period import PeriodIndex, Period from pandas.tseries.frequencies import get_period_alias, get_base_alias from pandas.tseries.offsets import DateOffset -from pandas.compat import range, lrange, lmap, map, zip +from pandas.compat import range, lrange, lmap, map, zip, string_types import pandas.compat as compat try: # mpl optional @@ -837,9 +837,11 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, self.axes = None # parse errorbar input if given - for err_dim in 'xy': - if err_dim+'err' in kwds: - kwds[err_dim+'err'] = self._parse_errorbars(error_dim=err_dim, **kwds) + xerr = kwds.pop('xerr', None) + yerr = kwds.pop('yerr', None) + self.errors = {} + for kw, err in zip(['xerr', 'yerr'], [xerr, yerr]): + self.errors[kw] = self._parse_errorbars(kw, err) if not isinstance(secondary_y, (bool, tuple, list, np.ndarray)): secondary_y = [secondary_y] @@ -1185,8 +1187,7 @@ def _get_plot_function(self): the presence of errorbar keywords. ''' - if ('xerr' not in self.kwds) and \ - ('yerr' not in self.kwds): + if all(e is None for e in self.errors.values()): plotf = self.plt.Axes.plot else: plotf = self.plt.Axes.errorbar @@ -1266,7 +1267,7 @@ def _maybe_add_color(self, colors, kwds, style, i): if has_color and (style is None or re.match('[a-z]+', style) is None): kwds['color'] = colors[i % len(colors)] - def _parse_errorbars(self, error_dim='y', **kwds): + def _parse_errorbars(self, label, err): ''' Look for error keyword arguments and return the actual errorbar data or return the error DataFrame/dict @@ -1280,47 +1281,48 @@ def _parse_errorbars(self, error_dim='y', **kwds): str: the name of the column within the plotted DataFrame ''' - err_kwd = kwds.pop(error_dim+'err', None) - if err_kwd is None: + if err is None: return None from pandas import DataFrame, Series - def match_labels(data, err): - err = err.reindex_axis(data.index) - return err + def match_labels(data, e): + e = e.reindex_axis(data.index) + return e # key-matched DataFrame - if isinstance(err_kwd, DataFrame): - err = err_kwd - err = match_labels(self.data, err) + if isinstance(err, DataFrame): + err = match_labels(self.data, err) # key-matched dict - elif isinstance(err_kwd, dict): - err = err_kwd + elif isinstance(err, dict): + pass # Series of error values - elif isinstance(err_kwd, Series): + elif isinstance(err, Series): # broadcast error series across data - err = match_labels(self.data, err_kwd) + err = match_labels(self.data, err) err = np.atleast_2d(err) err = np.tile(err, (self.nseries, 1)) # errors are a column in the dataframe - elif isinstance(err_kwd, str): - err = np.atleast_2d(self.data[err_kwd].values) - self.data = self.data[self.data.columns.drop(err_kwd)] + elif isinstance(err, string_types): + evalues = self.data[err].values + self.data = self.data[self.data.columns.drop(err)] + err = np.atleast_2d(evalues) err = np.tile(err, (self.nseries, 1)) - elif isinstance(err_kwd, (tuple, list, np.ndarray)): - - # raw error values - err = np.atleast_2d(err_kwd) + elif com.is_list_like(err): + if com.is_iterator(err): + err = np.atleast_2d(list(err)) + else: + # raw error values + err = np.atleast_2d(err) err_shape = err.shape # asymmetrical error bars - if err.ndim==3: + if err.ndim == 3: if (err_shape[0] != self.nseries) or \ (err_shape[1] != 2) or \ (err_shape[2] != len(self.data)): @@ -1330,15 +1332,39 @@ def match_labels(data, err): raise ValueError(msg) # broadcast errors to each data series - if len(err)==1: + if len(err) == 1: err = np.tile(err, (self.nseries, 1)) + elif com.is_number(err): + err = np.tile([err], (self.nseries, len(self.data))) + else: - msg = "No valid %serr detected" % error_dim + msg = "No valid %s detected" % label raise ValueError(msg) return err + def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): + from pandas import DataFrame + errors = {} + + for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]): + if flag: + err = self.errors[kw] + # user provided label-matched dataframe of errors + if isinstance(err, (DataFrame, dict)): + if label is not None and label in err.keys(): + err = err[label] + else: + err = None + elif index is not None and err is not None: + err = err[index] + + if err is not None: + errors[kw] = err + return errors + + class KdePlot(MPLPlot): def __init__(self, data, bw_method=None, ind=None, **kwargs): MPLPlot.__init__(self, data, **kwargs) @@ -1418,6 +1444,14 @@ def _make_plot(self): self._add_legend_handle(scatter, label) + errors_x = self._get_errorbars(label=x, index=0, yerr=False) + errors_y = self._get_errorbars(label=y, index=1, xerr=False) + if len(errors_x) > 0 or len(errors_y) > 0: + err_kwds = dict(errors_x, **errors_y) + if 'color' in self.kwds: + err_kwds['color'] = self.kwds['color'] + ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds) + def _post_plot_logic(self): ax = self.axes[0] x, y = self.x, self.y @@ -1558,16 +1592,9 @@ def _make_plot(self): kwds = self.kwds.copy() self._maybe_add_color(colors, kwds, style, i) - for err_kw in ['xerr', 'yerr']: - # user provided label-matched dataframe of errors - if err_kw in kwds: - if isinstance(kwds[err_kw], (DataFrame, dict)): - if label in kwds[err_kw].keys(): - kwds[err_kw] = kwds[err_kw][label] - else: del kwds[err_kw] - elif kwds[err_kw] is not None: - kwds[err_kw] = kwds[err_kw][i] - + errors = self._get_errorbars(label=label, index=i) + kwds = dict(kwds, **errors) + label = com.pprint_thing(label) # .encode('utf-8') kwds['label'] = label @@ -1629,7 +1656,6 @@ def _plot(data, ax, label, style, **kwds): return _plot def _make_ts_plot(self, data, **kwargs): - from pandas.core.frame import DataFrame colors = self._get_colors() plotf = self._get_ts_plot_function() @@ -1641,15 +1667,8 @@ def _make_ts_plot(self, data, **kwargs): self._maybe_add_color(colors, kwds, style, i) - # key-matched DataFrame of errors - if 'yerr' in kwds: - yerr = kwds['yerr'] - if isinstance(yerr, (DataFrame, dict)): - if label in yerr.keys(): - kwds['yerr'] = yerr[label] - else: del kwds['yerr'] - else: - kwds['yerr'] = yerr[i] + errors = self._get_errorbars(label=label, index=i, xerr=False) + kwds = dict(kwds, **errors) label = com.pprint_thing(label) @@ -1833,8 +1852,6 @@ def f(ax, x, y, w, start=None, log=self.log, **kwds): def _make_plot(self): import matplotlib as mpl - from pandas import DataFrame, Series - # mpl decided to make their version string unicode across all Python # versions for mpl >= 1.3 so we have to call str here for python 2 mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1') @@ -1853,15 +1870,8 @@ def _make_plot(self): kwds = self.kwds.copy() kwds['color'] = colors[i % ncolors] - for err_kw in ['xerr', 'yerr']: - if err_kw in kwds: - # user provided label-matched dataframe of errors - if isinstance(kwds[err_kw], (DataFrame, dict)): - if label in kwds[err_kw].keys(): - kwds[err_kw] = kwds[err_kw][label] - else: del kwds[err_kw] - elif kwds[err_kw] is not None: - kwds[err_kw] = kwds[err_kw][i] + errors = self._get_errorbars(label=label, index=i) + kwds = dict(kwds, **errors) label = com.pprint_thing(label) @@ -2074,7 +2084,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, for kw in ['xerr', 'yerr']: if (kw in kwds) and \ - (isinstance(kwds[kw], str) or com.is_integer(kwds[kw])): + (isinstance(kwds[kw], string_types) or com.is_integer(kwds[kw])): try: kwds[kw] = frame[kwds[kw]] except (IndexError, KeyError, TypeError):