diff --git a/doc/source/release.rst b/doc/source/release.rst index 7cf2bec0f4144..b91e307bf7c69 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -59,6 +59,8 @@ New features Date is used primarily in astronomy and represents the number of days from noon, January 1, 4713 BC. Because nanoseconds are used to define the time in pandas the actual range of dates that you can use is 1678 AD to 2262 AD. (:issue:`4041`) +- Added error bar support to the ``.plot`` method of ``DataFrame`` and ``Series`` (:issue:`3796`) + API Changes ~~~~~~~~~~~ @@ -126,9 +128,9 @@ API Changes DataFrame returned by ``GroupBy.apply`` (:issue:`6124`). This facilitates ``DataFrame.stack`` operations where the name of the column index is used as the name of the inserted column containing the pivoted data. - -- The :func:`pivot_table`/:meth:`DataFrame.pivot_table` and :func:`crosstab` functions - now take arguments ``index`` and ``columns`` instead of ``rows`` and ``cols``. A + +- The :func:`pivot_table`/:meth:`DataFrame.pivot_table` and :func:`crosstab` functions + now take arguments ``index`` and ``columns`` instead of ``rows`` and ``cols``. A ``FutureWarning`` is raised to alert that the old ``rows`` and ``cols`` arguments will not be supported in a future release (:issue:`5505`) diff --git a/doc/source/v0.14.0.txt b/doc/source/v0.14.0.txt index ea321cbab545a..463e4f2a3a49c 100644 --- a/doc/source/v0.14.0.txt +++ b/doc/source/v0.14.0.txt @@ -286,6 +286,20 @@ You can use a right-hand-side of an alignable object as well. df2.loc[idx[:,:,['C1','C3']],:] = df2*1000 df2 +Plotting With Errorbars +~~~~~~~~~~~~~~~~~~~~~~~ + +Plotting with error bars is now supported in the ``.plot`` method of ``DataFrame`` and ``Series`` objects (:issue:`3796`). + +x and y errorbars are supported and can be supplied using the ``xerr`` and ``yerr`` keyword arguments to ``.plot()``. The error values can be specified using a variety of formats. 
+ +- As a ``DataFrame`` or ``dict`` of errors with one or more of the column names (or dictionary keys) matching one or more of the column names of the plotting ``DataFrame`` or matching the ``name`` attribute of the ``Series`` +- As a ``str`` indicating which of the columns of the plotting ``DataFrame`` contain the error values +- As raw values (``list``, ``tuple``, or ``np.ndarray``). Must be the same length as the plotting ``DataFrame``/``Series`` + +Asymmetrical error bars are also supported; however, raw error values must be provided in this case. For a ``M`` length ``Series``, a ``Mx2`` array should be provided indicating lower and upper (or left and right) errors. For a ``MxN`` ``DataFrame``, asymmetrical errors should be in a ``Mx2xN`` array. + + Prior Version Deprecations/Changes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index 5827f2e971e42..bc0bf69df1282 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -381,6 +381,40 @@ columns: plt.close('all') +.. _visualization.errorbars: + +Plotting With Error Bars +~~~~~~~~~~~~~~~~~~~~~~~~ +Plotting with error bars is now supported in the ``.plot`` method of ``DataFrame`` and ``Series`` objects. + +x and y errorbars are supported and can be supplied using the ``xerr`` and ``yerr`` keyword arguments to ``.plot()``. The error values can be specified using a variety of formats. + +- As a ``DataFrame`` or ``dict`` of errors with column names matching the ``columns`` attribute of the plotting ``DataFrame`` or matching the ``name`` attribute of the ``Series`` +- As a ``str`` indicating which of the columns of the plotting ``DataFrame`` contain the error values +- As raw values (``list``, ``tuple``, or ``np.ndarray``). Must be the same length as the plotting ``DataFrame``/``Series`` + +Asymmetrical error bars are also supported; however, raw error values must be provided in this case. 
For a ``M`` length ``Series``, a ``Mx2`` array should be provided indicating lower and upper (or left and right) errors. For a ``MxN`` ``DataFrame``, asymmetrical errors should be in a ``Mx2xN`` array. + +Here is an example of one way to easily plot group means with standard deviations from the raw data. + +.. ipython:: python + + # Generate the data + ix3 = pd.MultiIndex.from_arrays([['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'], ['foo', 'foo', 'bar', 'bar', 'foo', 'foo', 'bar', 'bar']], names=['letter', 'word']) + df3 = pd.DataFrame({'data1': [3, 2, 4, 3, 2, 4, 3, 2], 'data2': [6, 5, 7, 5, 4, 5, 6, 5]}, index=ix3) + + # Group by index labels and take the means and standard deviations for each group + gp3 = df3.groupby(level=('letter', 'word')) + means = gp3.mean() + errors = gp3.std() + means + errors + + # Plot + fig, ax = plt.subplots() + @savefig errorbar_example.png + means.plot(yerr=errors, ax=ax, kind='bar') + .. _visualization.scatter_matrix: Scatter plot matrix diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 30ba5cd5a70fe..2752d12765fad 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -360,6 +360,35 @@ def test_dup_datetime_index_plot(self): s = Series(values, index=index) _check_plot_works(s.plot) + @slow + def test_errorbar_plot(self): + + s = Series(np.arange(10)) + s_err = np.random.randn(10) + + # test line and bar plots + kinds = ['line', 'bar'] + for kind in kinds: + _check_plot_works(s.plot, yerr=Series(s_err), kind=kind) + _check_plot_works(s.plot, yerr=s_err, kind=kind) + _check_plot_works(s.plot, yerr=s_err.tolist(), kind=kind) + + _check_plot_works(s.plot, xerr=s_err) + + # test time series plotting + ix = date_range('1/1/2000', '1/1/2001', freq='M') + ts = Series(np.arange(12), index=ix) + ts_err = Series(np.random.randn(12), index=ix) + + _check_plot_works(ts.plot, yerr=ts_err) + + # check incorrect lengths and types + with tm.assertRaises(ValueError): + s.plot(yerr=np.arange(11)) 
+ + s_err = ['zzz']*10 + with tm.assertRaises(TypeError): + s.plot(yerr=s_err) @tm.mplskip class TestDataFramePlots(tm.TestCase): @@ -1015,6 +1044,104 @@ def test_allow_cmap(self): df.plot(kind='hexbin', x='A', y='B', cmap='YlGn', colormap='BuGn') + def test_errorbar_plot(self): + + d = {'x': np.arange(12), 'y': np.arange(12, 0, -1)} + df = DataFrame(d) + d_err = {'x': np.ones(12)*0.2, 'y': np.ones(12)*0.4} + df_err = DataFrame(d_err) + + # check line plots + _check_plot_works(df.plot, yerr=df_err, logy=True) + _check_plot_works(df.plot, yerr=df_err, logx=True, logy=True) + + kinds = ['line', 'bar', 'barh'] + for kind in kinds: + _check_plot_works(df.plot, yerr=df_err['x'], kind=kind) + _check_plot_works(df.plot, yerr=d_err, kind=kind) + _check_plot_works(df.plot, yerr=df_err, xerr=df_err, kind=kind) + _check_plot_works(df.plot, yerr=df_err['x'], xerr=df_err['x'], kind=kind) + _check_plot_works(df.plot, yerr=df_err, xerr=df_err, subplots=True, kind=kind) + + _check_plot_works((df+1).plot, yerr=df_err, xerr=df_err, kind='bar', log=True) + + # yerr is raw error values + _check_plot_works(df['y'].plot, yerr=np.ones(12)*0.4) + _check_plot_works(df.plot, yerr=np.ones((2, 12))*0.4) + + # yerr is column name + df['yerr'] = np.ones(12)*0.2 + _check_plot_works(df.plot, y='y', x='x', yerr='yerr') + + with tm.assertRaises(ValueError): + df.plot(yerr=np.random.randn(11)) + + df_err = DataFrame({'x': ['zzz']*12, 'y': ['zzz']*12}) + with tm.assertRaises(TypeError): + df.plot(yerr=df_err) + + @slow + def test_errorbar_with_integer_column_names(self): + # test with integer column names + df = DataFrame(np.random.randn(10, 2)) + df_err = DataFrame(np.random.randn(10, 2)) + _check_plot_works(df.plot, yerr=df_err) + _check_plot_works(df.plot, y=0, yerr=1) + + @slow + def test_errorbar_with_partial_columns(self): + df = DataFrame(np.random.randn(10, 3)) + df_err = DataFrame(np.random.randn(10, 2), columns=[0, 2]) + kinds = ['line', 'bar'] + for kind in kinds: + 
_check_plot_works(df.plot, yerr=df_err, kind=kind) + + ix = date_range('1/1/2000', periods=10, freq='M') + df.set_index(ix, inplace=True) + df_err.set_index(ix, inplace=True) + _check_plot_works(df.plot, yerr=df_err, kind='line') + + @slow + def test_errorbar_timeseries(self): + + d = {'x': np.arange(12), 'y': np.arange(12, 0, -1)} + d_err = {'x': np.ones(12)*0.2, 'y': np.ones(12)*0.4} + + # check time-series plots + ix = date_range('1/1/2000', '1/1/2001', freq='M') + tdf = DataFrame(d, index=ix) + tdf_err = DataFrame(d_err, index=ix) + + kinds = ['line', 'bar', 'barh'] + for kind in kinds: + _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) + _check_plot_works(tdf.plot, yerr=d_err, kind=kind) + _check_plot_works(tdf.plot, y='y', kind=kind) + _check_plot_works(tdf.plot, y='y', yerr='x', kind=kind) + _check_plot_works(tdf.plot, yerr=tdf_err, kind=kind) + _check_plot_works(tdf.plot, kind=kind, subplots=True) + + + def test_errorbar_asymmetrical(self): + + np.random.seed(0) + err = np.random.rand(3, 2, 5) + + data = np.random.randn(5, 3) + df = DataFrame(data) + + ax = df.plot(yerr=err, xerr=err/2) + + self.assertEqual(ax.lines[7].get_ydata()[0], data[0,1]-err[1,0,0]) + self.assertEqual(ax.lines[8].get_ydata()[0], data[0,1]+err[1,1,0]) + + self.assertEqual(ax.lines[5].get_xdata()[0], -err[1,0,0]/2) + self.assertEqual(ax.lines[6].get_xdata()[0], err[1,1,0]/2) + + with tm.assertRaises(ValueError): + df.plot(yerr=err.T) + + tm.close() @tm.mplskip class TestDataFrameGroupByPlots(tm.TestCase): diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 7038284b6c2a0..507e0127a5062 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -831,6 +831,11 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, self.fig = fig self.axes = None + # parse errorbar input if given + for err_dim in 'xy': + if err_dim+'err' in kwds: + kwds[err_dim+'err'] = self._parse_errorbars(error_dim=err_dim, **kwds) + if not 
isinstance(secondary_y, (bool, tuple, list, np.ndarray)): secondary_y = [secondary_y] self.secondary_y = secondary_y @@ -971,6 +976,11 @@ def _setup_subplots(self): axes = [ax] + if self.logx: + [a.set_xscale('log') for a in axes] + if self.logy: + [a.set_yscale('log') for a in axes] + self.fig = fig self.axes = axes @@ -1090,14 +1100,16 @@ def _is_datetype(self): 'time')) def _get_plot_function(self): - if self.logy: - plotf = self.plt.Axes.semilogy - elif self.logx: - plotf = self.plt.Axes.semilogx - elif self.loglog: - plotf = self.plt.Axes.loglog - else: + ''' + Returns the matplotlib plotting function (plot or errorbar) based on + the presence of errorbar keywords. + ''' + + if ('xerr' not in self.kwds) and \ + ('yerr' not in self.kwds): plotf = self.plt.Axes.plot + else: + plotf = self.plt.Axes.errorbar return plotf @@ -1180,6 +1192,78 @@ def _get_marked_label(self, label, col_num): else: return label + def _parse_errorbars(self, error_dim='y', **kwds): + ''' + Look for error keyword arguments and return the actual errorbar data + or return the error DataFrame/dict + + Error bars can be specified in several ways: + Series: the user provides a pandas.Series object of the same + length as the data + ndarray: provides a np.ndarray of the same length as the data + DataFrame/dict: error values are paired with keys matching the + key in the plotted DataFrame + str: the name of the column within the plotted DataFrame + ''' + + err_kwd = kwds.pop(error_dim+'err', None) + if err_kwd is None: + return None + + from pandas import DataFrame, Series + + def match_labels(data, err): + err = err.reindex_axis(data.index) + return err + + # key-matched DataFrame + if isinstance(err_kwd, DataFrame): + err = err_kwd + err = match_labels(self.data, err) + + # key-matched dict + elif isinstance(err_kwd, dict): + err = err_kwd + + # Series of error values + elif isinstance(err_kwd, Series): + # broadcast error series across data + err = match_labels(self.data, err_kwd) + err = 
np.atleast_2d(err) + err = np.tile(err, (self.nseries, 1)) + + # errors are a column in the dataframe + elif isinstance(err_kwd, str): + err = np.atleast_2d(self.data[err_kwd].values) + self.data = self.data[self.data.columns.drop(err_kwd)] + err = np.tile(err, (self.nseries, 1)) + + elif isinstance(err_kwd, (tuple, list, np.ndarray)): + + # raw error values + err = np.atleast_2d(err_kwd) + + err_shape = err.shape + + # asymmetrical error bars + if err.ndim==3: + if (err_shape[0] != self.nseries) or \ + (err_shape[1] != 2) or \ + (err_shape[2] != len(self.data)): + msg = "Asymmetrical error bars should be provided " + \ + "with the shape (%u, 2, %u)" % \ + (self.nseries, len(self.data)) + raise ValueError(msg) + + # broadcast errors to each data series + if len(err)==1: + err = np.tile(err, (self.nseries, 1)) + + else: + msg = "No valid %serr detected" % error_dim + raise ValueError(msg) + + return err class KdePlot(MPLPlot): def __init__(self, data, bw_method=None, ind=None, **kwargs): @@ -1191,7 +1275,7 @@ def _make_plot(self): from scipy.stats import gaussian_kde from scipy import __version__ as spv from distutils.version import LooseVersion - plotf = self._get_plot_function() + plotf = self.plt.Axes.plot colors = self._get_colors() for i, (label, y) in enumerate(self._iter_data()): ax = self._get_ax(i) @@ -1376,8 +1460,9 @@ def _make_plot(self): # this is slightly deceptive if not self.x_compat and self.use_index and self._use_dynamic_x(): data = self._maybe_convert_index(self.data) - self._make_ts_plot(data, **self.kwds) + self._make_ts_plot(data) else: + from pandas.core.frame import DataFrame lines = [] labels = [] x = self._get_xticks(convert_period=True) @@ -1391,6 +1476,16 @@ def _make_plot(self): kwds = self.kwds.copy() self._maybe_add_color(colors, kwds, style, i) + for err_kw in ['xerr', 'yerr']: + # user provided label-matched dataframe of errors + if err_kw in kwds: + if isinstance(kwds[err_kw], (DataFrame, dict)): + if label in kwds[err_kw].keys(): 
+ kwds[err_kw] = kwds[err_kw][label] + else: del kwds[err_kw] + elif kwds[err_kw] is not None: + kwds[err_kw] = kwds[err_kw][i] + label = com.pprint_thing(label) # .encode('utf-8') mask = com.isnull(y) @@ -1399,10 +1494,11 @@ def _make_plot(self): y = np.ma.masked_where(mask, y) kwds['label'] = label - if style is None: - args = (ax, x, y) - else: + # prevent style kwarg from going to errorbar, where it is unsupported + if style is not None and plotf.__name__=='plot': args = (ax, x, y, style) + else: + args = (ax, x, y) newline = plotf(*args, **kwds)[0] lines.append(newline) @@ -1422,6 +1518,8 @@ def _make_plot(self): def _make_ts_plot(self, data, **kwargs): from pandas.tseries.plotting import tsplot + from pandas.core.frame import DataFrame + kwargs = kwargs.copy() colors = self._get_colors() @@ -1430,8 +1528,15 @@ def _make_ts_plot(self, data, **kwargs): labels = [] def _plot(data, col_num, ax, label, style, **kwds): - newlines = tsplot(data, plotf, ax=ax, label=label, - style=style, **kwds) + + if plotf.__name__=='plot': + newlines = tsplot(data, plotf, ax=ax, label=label, + style=style, **kwds) + # errorbar function does not support style argument + elif plotf.__name__=='errorbar': + newlines = tsplot(data, plotf, ax=ax, label=label, + **kwds) + ax.grid(self.grid) lines.append(newlines[0]) @@ -1444,19 +1549,33 @@ def _plot(data, col_num, ax, label, style, **kwds): ax = self._get_ax(0) # self.axes[0] style = self.style or '' label = com.pprint_thing(self.label) - kwds = kwargs.copy() + kwds = self.kwds.copy() self._maybe_add_color(colors, kwds, style, 0) + if 'yerr' in kwds: + kwds['yerr'] = kwds['yerr'][0] + _plot(data, 0, ax, label, self.style, **kwds) + else: for i, col in enumerate(data.columns): label = com.pprint_thing(col) ax = self._get_ax(i) style = self._get_style(i, col) - kwds = kwargs.copy() + kwds = self.kwds.copy() self._maybe_add_color(colors, kwds, style, i) + # key-matched DataFrame of errors + if 'yerr' in kwds: + yerr = kwds['yerr'] + if 
isinstance(yerr, (DataFrame, dict)): + if col in yerr.keys(): + kwds['yerr'] = yerr[col] + else: del kwds['yerr'] + else: + kwds['yerr'] = yerr[i] + _plot(data[col], i, ax, label, style, **kwds) self._make_legend(lines, labels) @@ -1581,6 +1700,7 @@ def f(ax, x, y, w, start=None, log=self.log, **kwds): def _make_plot(self): import matplotlib as mpl + from pandas import DataFrame, Series # mpl decided to make their version string unicode across all Python # versions for mpl >= 1.3 so we have to call str here for python 2 @@ -1599,10 +1719,25 @@ def _make_plot(self): for i, (label, y) in enumerate(self._iter_data()): ax = self._get_ax(i) - label = com.pprint_thing(label) kwds = self.kwds.copy() kwds['color'] = colors[i % ncolors] + for err_kw in ['xerr', 'yerr']: + if err_kw in kwds: + # user provided label-matched dataframe of errors + if isinstance(kwds[err_kw], (DataFrame, dict)): + if label in kwds[err_kw].keys(): + kwds[err_kw] = kwds[err_kw][label] + else: del kwds[err_kw] + elif kwds[err_kw] is not None: + kwds[err_kw] = kwds[err_kw][i] + + label = com.pprint_thing(label) + + if (('yerr' in kwds) or ('xerr' in kwds)) \ + and (kwds.get('ecolor') is None): + kwds['ecolor'] = mpl.rcParams['xtick.color'] + start = 0 if self.log: start = 1 @@ -1694,6 +1829,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, x : label or position, default None y : label or position, default None Allows plotting of one column versus another + yerr : DataFrame (with matching labels), Series, list-type (tuple, list, + ndarray), or str of column name containing y error values + xerr : similar functionality as yerr, but for x error values subplots : boolean, default False Make separate subplots for each time series sharex : boolean, default True @@ -1807,6 +1945,15 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, label = kwds.pop('label', label) ser = frame[y] ser.index.name = label + + for kw in ['xerr', 'yerr']: + if (kw in kwds) and \ 
+ (isinstance(kwds[kw], str) or com.is_integer(kwds[kw])): + try: + kwds[kw] = frame[kwds[kw]] + except (IndexError, KeyError, TypeError): + pass + return plot_series(ser, label=label, kind=kind, use_index=use_index, rot=rot, xticks=xticks, yticks=yticks, @@ -1876,6 +2023,8 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, ----- See matplotlib documentation online for more on this subject """ + + from pandas import DataFrame kind = _get_standard_kind(kind.lower().strip()) if kind == 'line': klass = LinePlot @@ -2611,6 +2760,7 @@ def _maybe_convert_date(x): x = conv_func(x) return x + if __name__ == '__main__': # import pandas.rpy.common as com # sales = com.load_data('sanfrancisco.home.sales', package='nutshell')