From c76d464353f2f34d75c1a24221b7bdb9ec318b71 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sun, 13 Apr 2014 07:37:30 +0900 Subject: [PATCH] CLN: simplify series plotting --- pandas/tests/test_graphics.py | 25 +++----- pandas/tools/plotting.py | 114 ++++++++++++++-------------------- 2 files changed, 57 insertions(+), 82 deletions(-) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index b1faf3047beea..e81cfd39ba78e 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -55,9 +55,10 @@ def test_plot(self): _check_plot_works(self.ts.plot, style='.', loglog=True) _check_plot_works(self.ts[:10].plot, kind='bar') _check_plot_works(self.iseries.plot) - _check_plot_works(self.series[:5].plot, kind='bar') - _check_plot_works(self.series[:5].plot, kind='line') - _check_plot_works(self.series[:5].plot, kind='barh') + + for kind in plotting._common_kinds: + _check_plot_works(self.series[:5].plot, kind=kind) + _check_plot_works(self.series[:10].plot, kind='barh') _check_plot_works(Series(randn(10)).plot, kind='bar', color='black') @@ -250,25 +251,19 @@ def test_bootstrap_plot(self): def test_invalid_plot_data(self): s = Series(list('abcd')) - kinds = 'line', 'bar', 'barh', 'kde', 'density' - - for kind in kinds: + for kind in plotting._common_kinds: with tm.assertRaises(TypeError): s.plot(kind=kind) @slow def test_valid_object_plot(self): s = Series(lrange(10), dtype=object) - kinds = 'line', 'bar', 'barh', 'kde', 'density' - - for kind in kinds: + for kind in plotting._common_kinds: _check_plot_works(s.plot, kind=kind) def test_partially_invalid_plot_data(self): s = Series(['a', 'b', 1.0, 2]) - kinds = 'line', 'bar', 'barh', 'kde', 'density' - - for kind in kinds: + for kind in plotting._common_kinds: with tm.assertRaises(TypeError): s.plot(kind=kind) @@ -1247,19 +1242,17 @@ def test_unordered_ts(self): assert_array_equal(ydata, np.array([1.0, 2.0, 3.0])) def test_all_invalid_plot_data(self): - kinds = 'line', 'bar', 'barh', 'kde', 'density' df = DataFrame(list('abcd')) - for kind in kinds: + for kind in plotting._common_kinds: with tm.assertRaises(TypeError): df.plot(kind=kind) @slow def test_partially_invalid_plot_data(self): with tm.RNGContext(42): - kinds = 'line', 'bar', 'barh', 'kde', 'density' df = DataFrame(randn(10, 2), dtype=object) df[np.random.rand(df.shape[0]) > 0.5] = 'a' - for kind in kinds: + for kind in plotting._common_kinds: with tm.assertRaises(TypeError): df.plot(kind=kind) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 4d348c37ed927..971aa7848c2fa 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -887,25 +887,31 @@ def _validate_color_args(self): " use one or the other or pass 'style' " "without a color symbol") - def _iter_data(self): - from pandas.core.frame import DataFrame - if isinstance(self.data, (Series, np.ndarray)): - yield self.label, np.asarray(self.data) - elif isinstance(self.data, DataFrame): - df = self.data + def _iter_data(self, data=None, keep_index=False): + if data is None: + data = self.data + from pandas.core.frame import DataFrame + if isinstance(data, (Series, np.ndarray)): + if keep_index is True: + yield self.label, data + else: + yield self.label, np.asarray(data) + elif isinstance(data, DataFrame): if self.sort_columns: - columns = com._try_sort(df.columns) + columns = com._try_sort(data.columns) else: - columns = df.columns + columns = data.columns for col in columns: # # is this right? # empty = df[col].count() == 0 # values = df[col].values if not empty else np.zeros(len(df)) - values = df[col].values - yield col, values + if keep_index is True: + yield col, data[col] + else: + yield col, data[col].values @property def nseries(self): @@ -1593,38 +1599,26 @@ def _plot(data, col_num, ax, label, style, **kwds): self._add_legend_handle(newlines[0], label, index=col_num) - if isinstance(data, Series): - ax = self._get_ax(0) # self.axes[0] - style = self.style or '' - label = com.pprint_thing(self.label) + it = self._iter_data(data=data, keep_index=True) + for i, (label, y) in enumerate(it): + ax = self._get_ax(i) + style = self._get_style(i, label) kwds = self.kwds.copy() - self._maybe_add_color(colors, kwds, style, 0) - - if 'yerr' in kwds: - kwds['yerr'] = kwds['yerr'][0] - _plot(data, 0, ax, label, self.style, **kwds) - - else: - for i, col in enumerate(data.columns): - label = com.pprint_thing(col) - ax = self._get_ax(i) - style = self._get_style(i, col) - kwds = self.kwds.copy() - - self._maybe_add_color(colors, kwds, style, i) + self._maybe_add_color(colors, kwds, style, i) - # key-matched DataFrame of errors - if 'yerr' in kwds: - yerr = kwds['yerr'] - if isinstance(yerr, (DataFrame, dict)): - if col in yerr.keys(): - kwds['yerr'] = yerr[col] - else: del kwds['yerr'] - else: - kwds['yerr'] = yerr[i] + # key-matched DataFrame of errors + if 'yerr' in kwds: + yerr = kwds['yerr'] + if isinstance(yerr, (DataFrame, dict)): + if label in yerr.keys(): + kwds['yerr'] = yerr[label] + else: del kwds['yerr'] + else: + kwds['yerr'] = yerr[i] - _plot(data[col], i, ax, label, style, **kwds) + label = com.pprint_thing(label) + _plot(y, i, ax, label, style, **kwds) def _maybe_convert_index(self, data): # tsplot converts automatically, but don't want to convert index @@ -1828,6 +1822,16 @@ class BoxPlot(MPLPlot): class HistPlot(MPLPlot): pass +# kinds supported by both dataframe and series +_common_kinds = ['line', 'bar', 'barh', 'kde', 'density'] +# kinds supported by dataframe +_dataframe_kinds = ['scatter', 'hexbin'] +_all_kinds = _common_kinds + _dataframe_kinds + +_plot_klass = {'line': LinePlot, 'bar': BarPlot, 'barh': BarPlot, + 'kde': KdePlot, + 'scatter': ScatterPlot, 'hexbin': HexBinPlot} + def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, sharey=False, use_index=True, figsize=None, grid=None, @@ -1921,21 +1925,14 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, is a function of one argument that reduces all the values in a bin to a single number (e.g. `mean`, `max`, `sum`, `std`). """ + kind = _get_standard_kind(kind.lower().strip()) - if kind == 'line': - klass = LinePlot - elif kind in ('bar', 'barh'): - klass = BarPlot - elif kind == 'kde': - klass = KdePlot - elif kind == 'scatter': - klass = ScatterPlot - elif kind == 'hexbin': - klass = HexBinPlot + if kind in _dataframe_kinds or kind in _common_kinds: + klass = _plot_klass[kind] else: raise ValueError('Invalid chart type given %s' % kind) - if kind == 'scatter': + if kind in _dataframe_kinds: plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, rot=rot,legend=legend, ax=ax, style=style, fontsize=fontsize, use_index=use_index, sharex=sharex, @@ -1944,16 +1941,6 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, figsize=figsize, logx=logx, logy=logy, sort_columns=sort_columns, secondary_y=secondary_y, **kwds) - elif kind == 'hexbin': - C = kwds.pop('C', None) # remove from kwargs so we can set default - plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, - rot=rot,legend=legend, ax=ax, style=style, - fontsize=fontsize, use_index=use_index, sharex=sharex, - sharey=sharey, xticks=xticks, yticks=yticks, - xlim=xlim, ylim=ylim, title=title, grid=grid, - figsize=figsize, logx=logx, logy=logy, - sort_columns=sort_columns, secondary_y=secondary_y, - C=C, **kwds) else: if x is not None: if com.is_integer(x) and not frame.columns.holds_integer(): @@ -2051,14 +2038,9 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, See matplotlib documentation online for more on this subject """ - from pandas import DataFrame kind = _get_standard_kind(kind.lower().strip()) - if kind == 'line': - klass = LinePlot - elif kind in ('bar', 'barh'): - klass = BarPlot - elif kind == 'kde': - klass = KdePlot + if kind in _common_kinds: + klass = _plot_klass[kind] else: raise ValueError('Invalid chart type given %s' % kind)