From 50803420f081c4cf8292d7ba603fbeff3665c86b Mon Sep 17 00:00:00 2001 From: sinhrks Date: Tue, 15 Sep 2015 20:16:39 +0900 Subject: [PATCH] ENH: plot now supports cyclic hatch --- ci/requirements-2.7_SLOW.run | 1 + ci/requirements-3.4_SLOW.run | 1 + pandas/tests/test_graphics.py | 76 +++++++++++++++++++++++++++++-- pandas/tools/plotting.py | 85 +++++++++++++++++++++++++++++++++-- 4 files changed, 156 insertions(+), 7 deletions(-) diff --git a/ci/requirements-2.7_SLOW.run b/ci/requirements-2.7_SLOW.run index f02a7cb8a309a..cc70e96a00365 100644 --- a/ci/requirements-2.7_SLOW.run +++ b/ci/requirements-2.7_SLOW.run @@ -2,6 +2,7 @@ python-dateutil pytz numpy=1.8.2 matplotlib=1.3.1 +cycler scipy patsy statsmodels diff --git a/ci/requirements-3.4_SLOW.run b/ci/requirements-3.4_SLOW.run index f9f226e3f1465..a53336ab85b15 100644 --- a/ci/requirements-3.4_SLOW.run +++ b/ci/requirements-3.4_SLOW.run @@ -12,6 +12,7 @@ scipy numexpr=2.4.4 pytables matplotlib +cycler lxml sqlalchemy bottleneck diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index bd19a83ce2b64..a8045e3e31291 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -2096,6 +2096,40 @@ def test_bar_bottom_left(self): result = [p.get_x() for p in ax.patches] self.assertEqual(result, [1] * 5) + @slow + def test_bar_hatches(self): + df = DataFrame(rand(4, 3)) + + ax = df.plot.bar() + result = [p._hatch for p in ax.patches] + self.assertEqual(result, [None] * 12) + tm.close() + + ax = df.plot.bar(hatch='*') + result = [p._hatch for p in ax.patches] + self.assertEqual(result, ['*'] * 12) + tm.close() + + from cycler import cycler + ax = df.plot.bar(hatch=cycler('hatch', ['*', '+', '//'])) + result = [p._hatch for p in ax.patches[:4]] + self.assertEqual(result, ['*'] * 4) + result = [p._hatch for p in ax.patches[4:8]] + self.assertEqual(result, ['+'] * 4) + result = [p._hatch for p in ax.patches[8:]] + self.assertEqual(result, ['//'] * 4) + tm.close() + + # length mismatch, loops implicitly + ax = df.plot.bar(hatch=cycler('hatch', ['*', '+'])) + result = [p._hatch for p in ax.patches[:4]] + self.assertEqual(result, ['*'] * 4) + result = [p._hatch for p in ax.patches[4:8]] + self.assertEqual(result, ['+'] * 4) + result = [p._hatch for p in ax.patches[8:]] + self.assertEqual(result, ['*'] * 4) + tm.close() + @slow def test_bar_nan(self): df = DataFrame({'A': [10, np.nan, 20], @@ -2953,6 +2987,10 @@ def test_line_colors_and_styles_subplots(self): self._check_colors(ax.get_lines(), linecolors=[c]) tm.close() + def _get_polycollection(self, ax): + from matplotlib.collections import PolyCollection + return [o for o in ax.get_children() if isinstance(o, PolyCollection)] + @slow def test_area_colors(self): from matplotlib import cm @@ -2963,7 +3001,7 @@ def test_area_colors(self): ax = df.plot.area(color=custom_colors) self._check_colors(ax.get_lines(), linecolors=custom_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = self._get_polycollection(ax) self._check_colors(poly, facecolors=custom_colors) handles, labels = ax.get_legend_handles_labels() @@ -2977,7 +3015,7 @@ def test_area_colors(self): ax = df.plot.area(colormap='jet') jet_colors = lmap(cm.jet, np.linspace(0, 1, len(df))) self._check_colors(ax.get_lines(), linecolors=jet_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = self._get_polycollection(ax) self._check_colors(poly, facecolors=jet_colors) handles, labels = ax.get_legend_handles_labels() @@ -2990,7 +3028,7 @@ def test_area_colors(self): # When stacked=False, alpha is set to 0.5 ax = df.plot.area(colormap=cm.jet, stacked=False) self._check_colors(ax.get_lines(), linecolors=jet_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = self._get_polycollection(ax) jet_with_alpha = [(c[0], c[1], c[2], 0.5) for c in jet_colors] self._check_colors(poly, facecolors=jet_with_alpha) @@ -3000,6 +3038,38 @@ def test_area_colors(self): for h in handles: self.assertEqual(h.get_alpha(), 0.5) + @slow + def test_area_hatches(self): + df = DataFrame(rand(4, 3)) + + ax = df.plot.area(stacked=False) + result = [x._hatch for x in self._get_polycollection(ax)] + self.assertEqual(result, [None] * 3) + tm.close() + + ax = df.plot.area(hatch='*', stacked=False) + result = [x._hatch for x in self._get_polycollection(ax)] + self.assertEqual(result, ['*'] * 3) + tm.close() + + from cycler import cycler + ax = df.plot.area(hatch=cycler('hatch', ['*', '+', '//']), + stacked=False) + poly = self._get_polycollection(ax) + self.assertEqual(poly[0]._hatch, '*') + self.assertEqual(poly[1]._hatch, '+') + self.assertEqual(poly[2]._hatch, '//') + tm.close() + + # length mismatch, loops implicitly + ax = df.plot.area(hatch=cycler('hatch', ['*', '+']), + stacked=False) + poly = self._get_polycollection(ax) + self.assertEqual(poly[0]._hatch, '*') + self.assertEqual(poly[1]._hatch, '+') + self.assertEqual(poly[2]._hatch, '*') + tm.close() + @slow def test_hist_colors(self): default_colors = self._maybe_unpack_cycler(self.plt.rcParams) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index b6c1926c1e7fc..bf30a46fd74e0 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -134,11 +134,14 @@ def _mpl_ge_1_5_0(): except ImportError: return False -if _mpl_ge_1_5_0(): +try: + _CYCLER_INSTALLED = True # Compat with mp 1.5, which uses cycler. import cycler colors = mpl_stylesheet.pop('axes.color_cycle') mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors) +except ImportError: + _CYCLER_INSTALLED = False def _get_standard_kind(kind): @@ -884,7 +887,6 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None, self.by = by self.kind = kind - self.sort_columns = sort_columns self.subplots = subplots @@ -959,7 +961,9 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None, self.table = table - self.kwds = kwds + # init and validatecycler keyword + # ToDo: _validate_color_arg may change kwds to list + self.kwds = self._init_cyclic_option(kwds) self._validate_color_args() @@ -993,6 +997,60 @@ def _validate_color_args(self): " use one or the other or pass 'style' " "without a color symbol") + def _init_cyclic_option(self, kwds): + """ + Convert passed kwds to cycler instance + """ + if not _CYCLER_INSTALLED: + return kwds + + option = {} + for key, value in compat.iteritems(kwds): + if isinstance(value, cycler.Cycler): + cycler_keys = value.keys + if len(cycler_keys) > 1: + msg = ("cycler should only contain " + "passed keyword '{0}': {1}") + raise ValueError(msg.format(key, value)) + if key not in cycler_keys: + msg = ("cycler must contain " + "passed keyword '{0}': {1}") + raise ValueError(msg.format(key, value)) + + elif isinstance(value, list): + # instanciate cycler + # to do: check mpl kwds which should handle list as it is + cycler_value = cycler.cycler(key, value) + value = cycler_value + + option[key] = value + return option + + def _get_cyclic_option(self, kwds, num): + """ + Get num-th element of cycler contained in passed kwds. + """ + if not _CYCLER_INSTALLED: + return kwds + + option = {} + for key, value in compat.iteritems(kwds): + if isinstance(value, cycler.Cycler): + # cycler() will implicitly loop, cycler will not + # cycler 0.10 or later is required + for i, v in enumerate(value()): + if i == num: + try: + option[key] = v[key] + except KeyError: + msg = ("cycler doesn't contain required " + "key '{0}': {1}") + raise ValueError(msg.format(key, value)) + break + else: + option[key] = value + return option + def _iter_data(self, data=None, keep_index=False, fillna=None): if data is None: data = self.data @@ -1013,6 +1071,9 @@ def _iter_data(self, data=None, keep_index=False, fillna=None): @property def nseries(self): + """ + Number of columns to be plotted. If data is a Series, return 1. + """ if self.data.ndim == 1: return 1 else: @@ -1161,6 +1222,7 @@ def _post_plot_logic_common(self, ax, data): self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize) self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize) + elif self.orientation == 'horizontal': if self._need_to_set_index: yticklabels = [labels.get(y, '') for y in ax.get_yticks()] @@ -1696,8 +1758,10 @@ def _make_plot(self): colors = self._get_colors() for i, (label, y) in enumerate(it): ax = self._get_ax(i) + kwds = self.kwds.copy() style, kwds = self._apply_style_colors(colors, kwds, i, label) + kwds = self._get_cyclic_option(kwds, i) errors = self._get_errorbars(label=label, index=i) kwds = dict(kwds, **errors) @@ -1829,13 +1893,20 @@ def __init__(self, data, **kwargs): if self.logy or self.loglog: raise ValueError("Log-y scales are not supported in area plot") + # kwds should not be passed to line + _fill_only_kwds = ['hatch'] + @classmethod def _plot(cls, ax, x, y, style=None, column_num=None, stacking_id=None, is_errorbar=False, **kwds): if column_num == 0: cls._initialize_stacker(ax, stacking_id, len(y)) y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label']) - lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds) + + line_kwds = kwds.copy() + for attr in cls._fill_only_kwds: + line_kwds.pop(attr, None) + lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds) # get data from the line to get coordinates for fill_between xdata, y_values = lines[0].get_data(orig=False) @@ -1939,6 +2010,8 @@ def _make_plot(self): kwds = self.kwds.copy() kwds['color'] = colors[i % ncolors] + kwds = self._get_cyclic_option(kwds, i) + errors = self._get_errorbars(label=label, index=i) kwds = dict(kwds, **errors) @@ -2064,6 +2137,7 @@ def _make_plot(self): ax = self._get_ax(i) kwds = self.kwds.copy() + kwds = self._get_cyclic_option(kwds, i) label = pprint_thing(label) kwds['label'] = label @@ -2180,6 +2254,7 @@ def _make_plot(self): ax.set_ylabel(label) kwds = self.kwds.copy() + kwds = self._get_cyclic_option(kwds, i) def blank_labeler(label, value): if value == 0: @@ -2320,6 +2395,7 @@ def _make_plot(self): for i, (label, y) in enumerate(self._iter_data()): ax = self._get_ax(i) kwds = self.kwds.copy() + kwds = self._get_cyclic_option(kwds, i) ret, bp = self._plot(ax, y, column_num=i, return_type=self.return_type, **kwds) @@ -2332,6 +2408,7 @@ def _make_plot(self): y = self.data.values.T ax = self._get_ax(0) kwds = self.kwds.copy() + kwds = self._get_cyclic_option(kwds, 0) ret, bp = self._plot(ax, y, column_num=0, return_type=self.return_type, **kwds)