Skip to content

Commit 5080342

Browse files
committed
ENH: plot now supports cyclic hatch
1 parent d38ee27 commit 5080342

File tree

4 files changed

+156
-7
lines changed

4 files changed

+156
-7
lines changed

ci/requirements-2.7_SLOW.run

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ python-dateutil
22
pytz
33
numpy=1.8.2
44
matplotlib=1.3.1
5+
cycler
56
scipy
67
patsy
78
statsmodels

ci/requirements-3.4_SLOW.run

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ scipy
1212
numexpr=2.4.4
1313
pytables
1414
matplotlib
15+
cycler
1516
lxml
1617
sqlalchemy
1718
bottleneck

pandas/tests/test_graphics.py

+73-3
Original file line numberDiff line numberDiff line change
@@ -2096,6 +2096,40 @@ def test_bar_bottom_left(self):
20962096
result = [p.get_x() for p in ax.patches]
20972097
self.assertEqual(result, [1] * 5)
20982098

2099+
@slow
2100+
def test_bar_hatches(self):
2101+
df = DataFrame(rand(4, 3))
2102+
2103+
ax = df.plot.bar()
2104+
result = [p._hatch for p in ax.patches]
2105+
self.assertEqual(result, [None] * 12)
2106+
tm.close()
2107+
2108+
ax = df.plot.bar(hatch='*')
2109+
result = [p._hatch for p in ax.patches]
2110+
self.assertEqual(result, ['*'] * 12)
2111+
tm.close()
2112+
2113+
from cycler import cycler
2114+
ax = df.plot.bar(hatch=cycler('hatch', ['*', '+', '//']))
2115+
result = [p._hatch for p in ax.patches[:4]]
2116+
self.assertEqual(result, ['*'] * 4)
2117+
result = [p._hatch for p in ax.patches[4:8]]
2118+
self.assertEqual(result, ['+'] * 4)
2119+
result = [p._hatch for p in ax.patches[8:]]
2120+
self.assertEqual(result, ['//'] * 4)
2121+
tm.close()
2122+
2123+
# length mismatch, loops implicitly
2124+
ax = df.plot.bar(hatch=cycler('hatch', ['*', '+']))
2125+
result = [p._hatch for p in ax.patches[:4]]
2126+
self.assertEqual(result, ['*'] * 4)
2127+
result = [p._hatch for p in ax.patches[4:8]]
2128+
self.assertEqual(result, ['+'] * 4)
2129+
result = [p._hatch for p in ax.patches[8:]]
2130+
self.assertEqual(result, ['*'] * 4)
2131+
tm.close()
2132+
20992133
@slow
21002134
def test_bar_nan(self):
21012135
df = DataFrame({'A': [10, np.nan, 20],
@@ -2953,6 +2987,10 @@ def test_line_colors_and_styles_subplots(self):
29532987
self._check_colors(ax.get_lines(), linecolors=[c])
29542988
tm.close()
29552989

2990+
def _get_polycollection(self, ax):
2991+
from matplotlib.collections import PolyCollection
2992+
return [o for o in ax.get_children() if isinstance(o, PolyCollection)]
2993+
29562994
@slow
29572995
def test_area_colors(self):
29582996
from matplotlib import cm
@@ -2963,7 +3001,7 @@ def test_area_colors(self):
29633001

29643002
ax = df.plot.area(color=custom_colors)
29653003
self._check_colors(ax.get_lines(), linecolors=custom_colors)
2966-
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
3004+
poly = self._get_polycollection(ax)
29673005
self._check_colors(poly, facecolors=custom_colors)
29683006

29693007
handles, labels = ax.get_legend_handles_labels()
@@ -2977,7 +3015,7 @@ def test_area_colors(self):
29773015
ax = df.plot.area(colormap='jet')
29783016
jet_colors = lmap(cm.jet, np.linspace(0, 1, len(df)))
29793017
self._check_colors(ax.get_lines(), linecolors=jet_colors)
2980-
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
3018+
poly = self._get_polycollection(ax)
29813019
self._check_colors(poly, facecolors=jet_colors)
29823020

29833021
handles, labels = ax.get_legend_handles_labels()
@@ -2990,7 +3028,7 @@ def test_area_colors(self):
29903028
# When stacked=False, alpha is set to 0.5
29913029
ax = df.plot.area(colormap=cm.jet, stacked=False)
29923030
self._check_colors(ax.get_lines(), linecolors=jet_colors)
2993-
poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)]
3031+
poly = self._get_polycollection(ax)
29943032
jet_with_alpha = [(c[0], c[1], c[2], 0.5) for c in jet_colors]
29953033
self._check_colors(poly, facecolors=jet_with_alpha)
29963034

@@ -3000,6 +3038,38 @@ def test_area_colors(self):
30003038
for h in handles:
30013039
self.assertEqual(h.get_alpha(), 0.5)
30023040

3041+
@slow
3042+
def test_area_hatches(self):
3043+
df = DataFrame(rand(4, 3))
3044+
3045+
ax = df.plot.area(stacked=False)
3046+
result = [x._hatch for x in self._get_polycollection(ax)]
3047+
self.assertEqual(result, [None] * 3)
3048+
tm.close()
3049+
3050+
ax = df.plot.area(hatch='*', stacked=False)
3051+
result = [x._hatch for x in self._get_polycollection(ax)]
3052+
self.assertEqual(result, ['*'] * 3)
3053+
tm.close()
3054+
3055+
from cycler import cycler
3056+
ax = df.plot.area(hatch=cycler('hatch', ['*', '+', '//']),
3057+
stacked=False)
3058+
poly = self._get_polycollection(ax)
3059+
self.assertEqual(poly[0]._hatch, '*')
3060+
self.assertEqual(poly[1]._hatch, '+')
3061+
self.assertEqual(poly[2]._hatch, '//')
3062+
tm.close()
3063+
3064+
# length mismatch, loops implicitly
3065+
ax = df.plot.area(hatch=cycler('hatch', ['*', '+']),
3066+
stacked=False)
3067+
poly = self._get_polycollection(ax)
3068+
self.assertEqual(poly[0]._hatch, '*')
3069+
self.assertEqual(poly[1]._hatch, '+')
3070+
self.assertEqual(poly[2]._hatch, '*')
3071+
tm.close()
3072+
30033073
@slow
30043074
def test_hist_colors(self):
30053075
default_colors = self._maybe_unpack_cycler(self.plt.rcParams)

pandas/tools/plotting.py

+81-4
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,14 @@ def _mpl_ge_1_5_0():
134134
except ImportError:
135135
return False
136136

137-
if _mpl_ge_1_5_0():
137+
try:
138+
_CYCLER_INSTALLED = True
138139
# Compat with mp 1.5, which uses cycler.
139140
import cycler
140141
colors = mpl_stylesheet.pop('axes.color_cycle')
141142
mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors)
143+
except ImportError:
144+
_CYCLER_INSTALLED = False
142145

143146

144147
def _get_standard_kind(kind):
@@ -884,7 +887,6 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
884887
self.by = by
885888

886889
self.kind = kind
887-
888890
self.sort_columns = sort_columns
889891

890892
self.subplots = subplots
@@ -959,7 +961,9 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
959961

960962
self.table = table
961963

962-
self.kwds = kwds
964+
# init and validatecycler keyword
965+
# ToDo: _validate_color_arg may change kwds to list
966+
self.kwds = self._init_cyclic_option(kwds)
963967

964968
self._validate_color_args()
965969

@@ -993,6 +997,60 @@ def _validate_color_args(self):
993997
" use one or the other or pass 'style' "
994998
"without a color symbol")
995999

1000+
def _init_cyclic_option(self, kwds):
1001+
"""
1002+
Convert passed kwds to cycler instance
1003+
"""
1004+
if not _CYCLER_INSTALLED:
1005+
return kwds
1006+
1007+
option = {}
1008+
for key, value in compat.iteritems(kwds):
1009+
if isinstance(value, cycler.Cycler):
1010+
cycler_keys = value.keys
1011+
if len(cycler_keys) > 1:
1012+
msg = ("cycler should only contain "
1013+
"passed keyword '{0}': {1}")
1014+
raise ValueError(msg.format(key, value))
1015+
if key not in cycler_keys:
1016+
msg = ("cycler must contain "
1017+
"passed keyword '{0}': {1}")
1018+
raise ValueError(msg.format(key, value))
1019+
1020+
elif isinstance(value, list):
1021+
# instanciate cycler
1022+
# to do: check mpl kwds which should handle list as it is
1023+
cycler_value = cycler.cycler(key, value)
1024+
value = cycler_value
1025+
1026+
option[key] = value
1027+
return option
1028+
1029+
def _get_cyclic_option(self, kwds, num):
1030+
"""
1031+
Get num-th element of cycler contained in passed kwds.
1032+
"""
1033+
if not _CYCLER_INSTALLED:
1034+
return kwds
1035+
1036+
option = {}
1037+
for key, value in compat.iteritems(kwds):
1038+
if isinstance(value, cycler.Cycler):
1039+
# cycler() will implicitly loop, cycler will not
1040+
# cycler 0.10 or later is required
1041+
for i, v in enumerate(value()):
1042+
if i == num:
1043+
try:
1044+
option[key] = v[key]
1045+
except KeyError:
1046+
msg = ("cycler doesn't contain required "
1047+
"key '{0}': {1}")
1048+
raise ValueError(msg.format(key, value))
1049+
break
1050+
else:
1051+
option[key] = value
1052+
return option
1053+
9961054
def _iter_data(self, data=None, keep_index=False, fillna=None):
9971055
if data is None:
9981056
data = self.data
@@ -1013,6 +1071,9 @@ def _iter_data(self, data=None, keep_index=False, fillna=None):
10131071

10141072
@property
10151073
def nseries(self):
1074+
"""
1075+
Number of columns to be plotted. If data is a Series, return 1.
1076+
"""
10161077
if self.data.ndim == 1:
10171078
return 1
10181079
else:
@@ -1161,6 +1222,7 @@ def _post_plot_logic_common(self, ax, data):
11611222
self._apply_axis_properties(ax.xaxis, rot=self.rot,
11621223
fontsize=self.fontsize)
11631224
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
1225+
11641226
elif self.orientation == 'horizontal':
11651227
if self._need_to_set_index:
11661228
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
@@ -1696,8 +1758,10 @@ def _make_plot(self):
16961758
colors = self._get_colors()
16971759
for i, (label, y) in enumerate(it):
16981760
ax = self._get_ax(i)
1761+
16991762
kwds = self.kwds.copy()
17001763
style, kwds = self._apply_style_colors(colors, kwds, i, label)
1764+
kwds = self._get_cyclic_option(kwds, i)
17011765

17021766
errors = self._get_errorbars(label=label, index=i)
17031767
kwds = dict(kwds, **errors)
@@ -1829,13 +1893,20 @@ def __init__(self, data, **kwargs):
18291893
if self.logy or self.loglog:
18301894
raise ValueError("Log-y scales are not supported in area plot")
18311895

1896+
# kwds should not be passed to line
1897+
_fill_only_kwds = ['hatch']
1898+
18321899
@classmethod
18331900
def _plot(cls, ax, x, y, style=None, column_num=None,
18341901
stacking_id=None, is_errorbar=False, **kwds):
18351902
if column_num == 0:
18361903
cls._initialize_stacker(ax, stacking_id, len(y))
18371904
y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label'])
1838-
lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds)
1905+
1906+
line_kwds = kwds.copy()
1907+
for attr in cls._fill_only_kwds:
1908+
line_kwds.pop(attr, None)
1909+
lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds)
18391910

18401911
# get data from the line to get coordinates for fill_between
18411912
xdata, y_values = lines[0].get_data(orig=False)
@@ -1939,6 +2010,8 @@ def _make_plot(self):
19392010
kwds = self.kwds.copy()
19402011
kwds['color'] = colors[i % ncolors]
19412012

2013+
kwds = self._get_cyclic_option(kwds, i)
2014+
19422015
errors = self._get_errorbars(label=label, index=i)
19432016
kwds = dict(kwds, **errors)
19442017

@@ -2064,6 +2137,7 @@ def _make_plot(self):
20642137
ax = self._get_ax(i)
20652138

20662139
kwds = self.kwds.copy()
2140+
kwds = self._get_cyclic_option(kwds, i)
20672141

20682142
label = pprint_thing(label)
20692143
kwds['label'] = label
@@ -2180,6 +2254,7 @@ def _make_plot(self):
21802254
ax.set_ylabel(label)
21812255

21822256
kwds = self.kwds.copy()
2257+
kwds = self._get_cyclic_option(kwds, i)
21832258

21842259
def blank_labeler(label, value):
21852260
if value == 0:
@@ -2320,6 +2395,7 @@ def _make_plot(self):
23202395
for i, (label, y) in enumerate(self._iter_data()):
23212396
ax = self._get_ax(i)
23222397
kwds = self.kwds.copy()
2398+
kwds = self._get_cyclic_option(kwds, i)
23232399

23242400
ret, bp = self._plot(ax, y, column_num=i,
23252401
return_type=self.return_type, **kwds)
@@ -2332,6 +2408,7 @@ def _make_plot(self):
23322408
y = self.data.values.T
23332409
ax = self._get_ax(0)
23342410
kwds = self.kwds.copy()
2411+
kwds = self._get_cyclic_option(kwds, 0)
23352412

23362413
ret, bp = self._plot(ax, y, column_num=0,
23372414
return_type=self.return_type, **kwds)

0 commit comments

Comments
 (0)