Skip to content

(WIP)ENH: plot now supports cycler #12547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements-2.7_SLOW.run
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ python-dateutil
pytz
numpy=1.8.2
matplotlib=1.3.1
cycler
scipy
patsy
statsmodels
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-3.4_SLOW.run
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ scipy
numexpr=2.4.4
pytables
matplotlib
cycler
lxml
sqlalchemy
bottleneck
Expand Down
76 changes: 73 additions & 3 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2096,6 +2096,40 @@ def test_bar_bottom_left(self):
result = [p.get_x() for p in ax.patches]
self.assertEqual(result, [1] * 5)

@slow
def test_bar_hatches(self):
df = DataFrame(rand(4, 3))

ax = df.plot.bar()
result = [p._hatch for p in ax.patches]
self.assertEqual(result, [None] * 12)
tm.close()

ax = df.plot.bar(hatch='*')
result = [p._hatch for p in ax.patches]
self.assertEqual(result, ['*'] * 12)
tm.close()

from cycler import cycler
ax = df.plot.bar(hatch=cycler('hatch', ['*', '+', '//']))
result = [p._hatch for p in ax.patches[:4]]
self.assertEqual(result, ['*'] * 4)
result = [p._hatch for p in ax.patches[4:8]]
self.assertEqual(result, ['+'] * 4)
result = [p._hatch for p in ax.patches[8:]]
self.assertEqual(result, ['//'] * 4)
tm.close()

# length mismatch, loops implicitly
ax = df.plot.bar(hatch=cycler('hatch', ['*', '+']))
result = [p._hatch for p in ax.patches[:4]]
self.assertEqual(result, ['*'] * 4)
result = [p._hatch for p in ax.patches[4:8]]
self.assertEqual(result, ['+'] * 4)
result = [p._hatch for p in ax.patches[8:]]
self.assertEqual(result, ['*'] * 4)
tm.close()

@slow
def test_bar_nan(self):
df = DataFrame({'A': [10, np.nan, 20],
Expand Down Expand Up @@ -2953,6 +2987,10 @@ def test_line_colors_and_styles_subplots(self):
self._check_colors(ax.get_lines(), linecolors=[c])
tm.close()

def _get_polycollection(self, ax):
from matplotlib.collections import PolyCollection
return [o for o in ax.get_children() if isinstance(o, PolyCollection)]

@slow
def test_area_colors(self):
from matplotlib import cm
Expand All @@ -2963,7 +3001,7 @@ def test_area_colors(self):

ax = df.plot.area(color=custom_colors)
self._check_colors(ax.get_lines(), linecolors=custom_colors)
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
poly = self._get_polycollection(ax)
self._check_colors(poly, facecolors=custom_colors)

handles, labels = ax.get_legend_handles_labels()
Expand All @@ -2977,7 +3015,7 @@ def test_area_colors(self):
ax = df.plot.area(colormap='jet')
jet_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
self._check_colors(ax.get_lines(), linecolors=jet_colors)
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
poly = self._get_polycollection(ax)
self._check_colors(poly, facecolors=jet_colors)

handles, labels = ax.get_legend_handles_labels()
Expand All @@ -2990,7 +3028,7 @@ def test_area_colors(self):
# When stacked=False, alpha is set to 0.5
ax = df.plot.area(colormap=cm.jet, stacked=False)
self._check_colors(ax.get_lines(), linecolors=jet_colors)
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
poly = self._get_polycollection(ax)
jet_with_alpha = [(c[0], c[1], c[2], 0.5) for c in jet_colors]
self._check_colors(poly, facecolors=jet_with_alpha)

Expand All @@ -3000,6 +3038,38 @@ def test_area_colors(self):
for h in handles:
self.assertEqual(h.get_alpha(), 0.5)

@slow
def test_area_hatches(self):
df = DataFrame(rand(4, 3))

ax = df.plot.area(stacked=False)
result = [x._hatch for x in self._get_polycollection(ax)]
self.assertEqual(result, [None] * 3)
tm.close()

ax = df.plot.area(hatch='*', stacked=False)
result = [x._hatch for x in self._get_polycollection(ax)]
self.assertEqual(result, ['*'] * 3)
tm.close()

from cycler import cycler
ax = df.plot.area(hatch=cycler('hatch', ['*', '+', '//']),
stacked=False)
poly = self._get_polycollection(ax)
self.assertEqual(poly[0]._hatch, '*')
self.assertEqual(poly[1]._hatch, '+')
self.assertEqual(poly[2]._hatch, '//')
tm.close()

# length mismatch, loops implicitly
ax = df.plot.area(hatch=cycler('hatch', ['*', '+']),
stacked=False)
poly = self._get_polycollection(ax)
self.assertEqual(poly[0]._hatch, '*')
self.assertEqual(poly[1]._hatch, '+')
self.assertEqual(poly[2]._hatch, '*')
tm.close()

@slow
def test_hist_colors(self):
default_colors = self._maybe_unpack_cycler(self.plt.rcParams)
Expand Down
85 changes: 81 additions & 4 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,14 @@ def _mpl_ge_1_5_0():
except ImportError:
return False

if _mpl_ge_1_5_0():
try:
_CYCLER_INSTALLED = True
# Compat with mp 1.5, which uses cycler.
import cycler
colors = mpl_stylesheet.pop('axes.color_cycle')
mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors)
except ImportError:
_CYCLER_INSTALLED = False


def _get_standard_kind(kind):
Expand Down Expand Up @@ -884,7 +887,6 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
self.by = by

self.kind = kind

self.sort_columns = sort_columns

self.subplots = subplots
Expand Down Expand Up @@ -959,7 +961,9 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,

self.table = table

self.kwds = kwds
# init and validatecycler keyword
# ToDo: _validate_color_arg may change kwds to list
self.kwds = self._init_cyclic_option(kwds)

self._validate_color_args()

Expand Down Expand Up @@ -993,6 +997,60 @@ def _validate_color_args(self):
" use one or the other or pass 'style' "
"without a color symbol")

def _init_cyclic_option(self, kwds):
"""
Convert passed kwds to cycler instance
"""
if not _CYCLER_INSTALLED:
return kwds

option = {}
for key, value in compat.iteritems(kwds):
if isinstance(value, cycler.Cycler):
cycler_keys = value.keys
if len(cycler_keys) > 1:
msg = ("cycler should only contain "
"passed keyword '{0}': {1}")
raise ValueError(msg.format(key, value))
if key not in cycler_keys:
msg = ("cycler must contain "
"passed keyword '{0}': {1}")
raise ValueError(msg.format(key, value))

elif isinstance(value, list):
# instanciate cycler
# to do: check mpl kwds which should handle list as it is
cycler_value = cycler.cycler(key, value)
value = cycler_value

option[key] = value
return option

def _get_cyclic_option(self, kwds, num):
"""
Get num-th element of cycler contained in passed kwds.
"""
if not _CYCLER_INSTALLED:
return kwds

option = {}
for key, value in compat.iteritems(kwds):
if isinstance(value, cycler.Cycler):
# cycler() will implicitly loop, cycler will not
# cycler 0.10 or later is required
for i, v in enumerate(value()):
if i == num:
try:
option[key] = v[key]
except KeyError:
msg = ("cycler doesn't contain required "
"key '{0}': {1}")
raise ValueError(msg.format(key, value))
break
else:
option[key] = value
return option

def _iter_data(self, data=None, keep_index=False, fillna=None):
if data is None:
data = self.data
Expand All @@ -1013,6 +1071,9 @@ def _iter_data(self, data=None, keep_index=False, fillna=None):

@property
def nseries(self):
"""
Number of columns to be plotted. If data is a Series, return 1.
"""
if self.data.ndim == 1:
return 1
else:
Expand Down Expand Up @@ -1161,6 +1222,7 @@ def _post_plot_logic_common(self, ax, data):
self._apply_axis_properties(ax.xaxis, rot=self.rot,
fontsize=self.fontsize)
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)

elif self.orientation == 'horizontal':
if self._need_to_set_index:
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
Expand Down Expand Up @@ -1696,8 +1758,10 @@ def _make_plot(self):
colors = self._get_colors()
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)

kwds = self.kwds.copy()
style, kwds = self._apply_style_colors(colors, kwds, i, label)
kwds = self._get_cyclic_option(kwds, i)

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)
Expand Down Expand Up @@ -1829,13 +1893,20 @@ def __init__(self, data, **kwargs):
if self.logy or self.loglog:
raise ValueError("Log-y scales are not supported in area plot")

# kwds should not be passed to line
_fill_only_kwds = ['hatch']

@classmethod
def _plot(cls, ax, x, y, style=None, column_num=None,
stacking_id=None, is_errorbar=False, **kwds):
if column_num == 0:
cls._initialize_stacker(ax, stacking_id, len(y))
y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label'])
lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds)

line_kwds = kwds.copy()
for attr in cls._fill_only_kwds:
line_kwds.pop(attr, None)
lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds)

# get data from the line to get coordinates for fill_between
xdata, y_values = lines[0].get_data(orig=False)
Expand Down Expand Up @@ -1939,6 +2010,8 @@ def _make_plot(self):
kwds = self.kwds.copy()
kwds['color'] = colors[i % ncolors]

kwds = self._get_cyclic_option(kwds, i)

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)

Expand Down Expand Up @@ -2064,6 +2137,7 @@ def _make_plot(self):
ax = self._get_ax(i)

kwds = self.kwds.copy()
kwds = self._get_cyclic_option(kwds, i)

label = pprint_thing(label)
kwds['label'] = label
Expand Down Expand Up @@ -2180,6 +2254,7 @@ def _make_plot(self):
ax.set_ylabel(label)

kwds = self.kwds.copy()
kwds = self._get_cyclic_option(kwds, i)

def blank_labeler(label, value):
if value == 0:
Expand Down Expand Up @@ -2320,6 +2395,7 @@ def _make_plot(self):
for i, (label, y) in enumerate(self._iter_data()):
ax = self._get_ax(i)
kwds = self.kwds.copy()
kwds = self._get_cyclic_option(kwds, i)

ret, bp = self._plot(ax, y, column_num=i,
return_type=self.return_type, **kwds)
Expand All @@ -2332,6 +2408,7 @@ def _make_plot(self):
y = self.data.values.T
ax = self._get_ax(0)
kwds = self.kwds.copy()
kwds = self._get_cyclic_option(kwds, 0)

ret, bp = self._plot(ax, y, column_num=0,
return_type=self.return_type, **kwds)
Expand Down