Skip to content

Commit 75bbf4c

Browse files
committed
CLN: Simplify LinePlot flow
1 parent 34cecd8 commit 75bbf4c

File tree

2 files changed

+68
-140
lines changed

2 files changed

+68
-140
lines changed

pandas/tools/plotting.py

+66-114
Original file line numberDiff line numberDiff line change
@@ -755,9 +755,9 @@ class MPLPlot(object):
755755
_default_rot = 0
756756

757757
_pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
758-
'mark_right']
758+
'mark_right', 'stacked']
759759
_attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
760-
'mark_right': True}
760+
'mark_right': True, 'stacked': False}
761761

762762
def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
763763
sharey=False, use_index=True,
@@ -1080,7 +1080,6 @@ def _make_legend(self):
10801080
for ax in self.axes:
10811081
ax.legend(loc='best')
10821082

1083-
10841083
def _get_ax_legend(self, ax):
10851084
leg = ax.get_legend()
10861085
other_ax = (getattr(ax, 'right_ax', None) or
@@ -1139,12 +1138,22 @@ def _get_plot_function(self):
11391138
Returns the matplotlib plotting function (plot or errorbar) based on
11401139
the presence of errorbar keywords.
11411140
'''
1142-
1143-
if all(e is None for e in self.errors.values()):
1144-
plotf = self.plt.Axes.plot
1145-
else:
1146-
plotf = self.plt.Axes.errorbar
1147-
1141+
errorbar = any(e is not None for e in self.errors.values())
1142+
def plotf(ax, x, y, style=None, **kwds):
1143+
mask = com.isnull(y)
1144+
if mask.any():
1145+
y = np.ma.array(y)
1146+
y = np.ma.masked_where(mask, y)
1147+
1148+
if errorbar:
1149+
return self.plt.Axes.errorbar(ax, x, y, **kwds)
1150+
else:
1151+
# prevent style kwarg from going to errorbar, where it is unsupported
1152+
if style is not None:
1153+
args = (ax, x, y, style)
1154+
else:
1155+
args = (ax, x, y)
1156+
return self.plt.Axes.plot(*args, **kwds)
11481157
return plotf
11491158

11501159
def _get_index_name(self):
@@ -1472,11 +1481,9 @@ def _post_plot_logic(self):
14721481
class LinePlot(MPLPlot):
14731482

14741483
def __init__(self, data, **kwargs):
1475-
self.stacked = kwargs.pop('stacked', False)
1476-
if self.stacked:
1477-
data = data.fillna(value=0)
1478-
14791484
MPLPlot.__init__(self, data, **kwargs)
1485+
if self.stacked:
1486+
self.data = self.data.fillna(value=0)
14801487
self.x_compat = plot_params['x_compat']
14811488
if 'x_compat' in self.kwds:
14821489
self.x_compat = bool(self.kwds.pop('x_compat'))
@@ -1533,56 +1540,39 @@ def _is_ts_plot(self):
15331540
return not self.x_compat and self.use_index and self._use_dynamic_x()
15341541

15351542
def _make_plot(self):
1536-
self._pos_prior = np.zeros(len(self.data))
1537-
self._neg_prior = np.zeros(len(self.data))
1543+
self._initialize_prior(len(self.data))
15381544

15391545
if self._is_ts_plot():
15401546
data = self._maybe_convert_index(self.data)
1541-
self._make_ts_plot(data)
1547+
x = data.index # dummy, not used
1548+
plotf = self._get_ts_plot_function()
1549+
it = self._iter_data(data=data, keep_index=True)
15421550
else:
15431551
x = self._get_xticks(convert_period=True)
1544-
15451552
plotf = self._get_plot_function()
1546-
colors = self._get_colors()
1547-
1548-
for i, (label, y) in enumerate(self._iter_data()):
1549-
ax = self._get_ax(i)
1550-
style = self._get_style(i, label)
1551-
kwds = self.kwds.copy()
1552-
self._maybe_add_color(colors, kwds, style, i)
1553+
it = self._iter_data()
15531554

1554-
errors = self._get_errorbars(label=label, index=i)
1555-
kwds = dict(kwds, **errors)
1556-
1557-
label = com.pprint_thing(label) # .encode('utf-8')
1558-
kwds['label'] = label
1559-
1560-
y_values = self._get_stacked_values(y, label)
1561-
1562-
if not self.stacked:
1563-
mask = com.isnull(y_values)
1564-
if mask.any():
1565-
y_values = np.ma.array(y_values)
1566-
y_values = np.ma.masked_where(mask, y_values)
1555+
colors = self._get_colors()
1556+
for i, (label, y) in enumerate(it):
1557+
ax = self._get_ax(i)
1558+
style = self._get_style(i, label)
1559+
kwds = self.kwds.copy()
1560+
self._maybe_add_color(colors, kwds, style, i)
15671561

1568-
# prevent style kwarg from going to errorbar, where it is unsupported
1569-
if style is not None and plotf.__name__ != 'errorbar':
1570-
args = (ax, x, y_values, style)
1571-
else:
1572-
args = (ax, x, y_values)
1562+
errors = self._get_errorbars(label=label, index=i)
1563+
kwds = dict(kwds, **errors)
15731564

1574-
newlines = plotf(*args, **kwds)
1575-
self._add_legend_handle(newlines[0], label, index=i)
1565+
label = com.pprint_thing(label) # .encode('utf-8')
1566+
kwds['label'] = label
1567+
y_values = self._get_stacked_values(y, label)
15761568

1577-
if self.stacked and not self.subplots:
1578-
if (y >= 0).all():
1579-
self._pos_prior += y
1580-
elif (y <= 0).all():
1581-
self._neg_prior += y
1569+
newlines = plotf(ax, x, y_values, style=style, **kwds)
1570+
self._update_prior(y)
1571+
self._add_legend_handle(newlines[0], label, index=i)
15821572

1583-
lines = _get_all_lines(ax)
1584-
left, right = _get_xlim(lines)
1585-
ax.set_xlim(left, right)
1573+
lines = _get_all_lines(ax)
1574+
left, right = _get_xlim(lines)
1575+
ax.set_xlim(left, right)
15861576

15871577
def _get_stacked_values(self, y, label):
15881578
if self.stacked:
@@ -1599,46 +1589,26 @@ def _get_stacked_values(self, y, label):
15991589
def _get_ts_plot_function(self):
16001590
from pandas.tseries.plotting import tsplot
16011591
plotf = self._get_plot_function()
1602-
1603-
def _plot(data, ax, label, style, **kwds):
1604-
# errorbar function does not support style argument
1605-
if plotf.__name__ == 'errorbar':
1606-
lines = tsplot(data, plotf, ax=ax, label=label,
1607-
**kwds)
1608-
return lines
1609-
else:
1610-
lines = tsplot(data, plotf, ax=ax, label=label,
1611-
style=style, **kwds)
1612-
return lines
1592+
def _plot(ax, x, data, style=None, **kwds):
1593+
# accept x to be consistent with normal plot func,
1594+
# x is not passed to tsplot as it uses data.index as x coordinate
1595+
lines = tsplot(data, plotf, ax=ax, style=style, **kwds)
1596+
return lines
16131597
return _plot
16141598

1615-
def _make_ts_plot(self, data, **kwargs):
1616-
colors = self._get_colors()
1617-
plotf = self._get_ts_plot_function()
1618-
1619-
it = self._iter_data(data=data, keep_index=True)
1620-
for i, (label, y) in enumerate(it):
1621-
ax = self._get_ax(i)
1622-
style = self._get_style(i, label)
1623-
kwds = self.kwds.copy()
1624-
1625-
self._maybe_add_color(colors, kwds, style, i)
1626-
1627-
errors = self._get_errorbars(label=label, index=i, xerr=False)
1628-
kwds = dict(kwds, **errors)
1629-
1630-
label = com.pprint_thing(label)
1631-
1632-
y_values = self._get_stacked_values(y, label)
1633-
1634-
newlines = plotf(y_values, ax, label, style, **kwds)
1635-
self._add_legend_handle(newlines[0], label, index=i)
1599+
def _initialize_prior(self, n):
1600+
self._pos_prior = np.zeros(n)
1601+
self._neg_prior = np.zeros(n)
16361602

1637-
if self.stacked and not self.subplots:
1638-
if (y >= 0).all():
1639-
self._pos_prior += y
1640-
elif (y <= 0).all():
1641-
self._neg_prior += y
1603+
def _update_prior(self, y):
1604+
if self.stacked and not self.subplots:
1605+
# tsplot resample may changedata length
1606+
if len(self._pos_prior) != len(y):
1607+
self._initialize_prior(len(y))
1608+
if (y >= 0).all():
1609+
self._pos_prior += y
1610+
elif (y <= 0).all():
1611+
self._neg_prior += y
16421612

16431613
def _maybe_convert_index(self, data):
16441614
# tsplot converts automatically, but don't want to convert index
@@ -1707,30 +1677,25 @@ def _get_plot_function(self):
17071677
if self.logy or self.loglog:
17081678
raise ValueError("Log-y scales are not supported in area plot")
17091679
else:
1710-
f = LinePlot._get_plot_function(self)
1711-
1712-
def plotf(*args, **kwds):
1713-
lines = f(*args, **kwds)
1680+
f = MPLPlot._get_plot_function(self)
1681+
def plotf(ax, x, y, style=None, **kwds):
1682+
lines = f(ax, x, y, style=style, **kwds)
17141683

1684+
# get data from the line
17151685
# insert fill_between starting point
1716-
y = args[2]
1686+
xdata, y_values = lines[0].get_data(orig=False)
1687+
17171688
if (y >= 0).all():
17181689
start = self._pos_prior
17191690
elif (y <= 0).all():
17201691
start = self._neg_prior
17211692
else:
17221693
start = np.zeros(len(y))
17231694

1724-
# get x data from the line
1725-
# to retrieve x coodinates of tsplot
1726-
xdata = lines[0].get_data()[0]
1727-
# remove style
1728-
args = (args[0], xdata, start, y)
1729-
17301695
if not 'color' in kwds:
17311696
kwds['color'] = lines[0].get_color()
17321697

1733-
self.plt.Axes.fill_between(*args, **kwds)
1698+
self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
17341699
return lines
17351700

17361701
return plotf
@@ -1746,15 +1711,6 @@ def _add_legend_handle(self, handle, label, index=None):
17461711
def _post_plot_logic(self):
17471712
LinePlot._post_plot_logic(self)
17481713

1749-
if self._is_ts_plot():
1750-
pass
1751-
else:
1752-
if self.xlim is None:
1753-
for ax in self.axes:
1754-
lines = _get_all_lines(ax)
1755-
left, right = _get_xlim(lines)
1756-
ax.set_xlim(left, right)
1757-
17581714
if self.ylim is None:
17591715
if (self.data >= 0).all().all():
17601716
for ax in self.axes:
@@ -1769,12 +1725,8 @@ class BarPlot(MPLPlot):
17691725
_default_rot = {'bar': 90, 'barh': 0}
17701726

17711727
def __init__(self, data, **kwargs):
1772-
self.stacked = kwargs.pop('stacked', False)
1773-
17741728
self.bar_width = kwargs.pop('width', 0.5)
1775-
17761729
pos = kwargs.pop('position', 0.5)
1777-
17781730
kwargs.setdefault('align', 'center')
17791731
self.tick_pos = np.arange(len(data))
17801732

pandas/tseries/plotting.py

+2-26
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from pandas.tseries.converter import (PeriodConverter, TimeSeries_DateLocator,
1919
TimeSeries_DateFormatter)
2020

21-
from pandas.tools.plotting import _get_all_lines, _get_xlim
22-
2321
#----------------------------------------------------------------------
2422
# Plotting functions and monkey patches
2523

@@ -59,25 +57,15 @@ def tsplot(series, plotf, **kwargs):
5957
# Set ax with freq info
6058
_decorate_axes(ax, freq, kwargs)
6159

62-
# mask missing values
63-
args = _maybe_mask(series)
64-
6560
# how to make sure ax.clear() flows through?
6661
if not hasattr(ax, '_plot_data'):
6762
ax._plot_data = []
6863
ax._plot_data.append((series, kwargs))
6964

70-
# styles
71-
style = kwargs.pop('style', None)
72-
if style is not None:
73-
args.append(style)
74-
75-
lines = plotf(ax, *args, **kwargs)
65+
lines = plotf(ax, series.index, series.values, **kwargs)
7666

7767
# set date formatter, locators and rescale limits
7868
format_dateaxis(ax, ax.freq)
79-
left, right = _get_xlim(_get_all_lines(ax))
80-
ax.set_xlim(left, right)
8169

8270
# x and y coord info
8371
ax.format_coord = lambda t, y: ("t = {0} "
@@ -165,8 +153,7 @@ def _replot_ax(ax, freq, plotf, kwargs):
165153
idx = series.index.asfreq(freq, how='S')
166154
series.index = idx
167155
ax._plot_data.append(series)
168-
args = _maybe_mask(series)
169-
lines.append(plotf(ax, *args, **kwds)[0])
156+
lines.append(plotf(ax, series.index, series.values, **kwds)[0])
170157
labels.append(com.pprint_thing(series.name))
171158

172159
return lines, labels
@@ -184,17 +171,6 @@ def _decorate_axes(ax, freq, kwargs):
184171
ax.date_axis_info = None
185172

186173

187-
def _maybe_mask(series):
188-
mask = isnull(series)
189-
if mask.any():
190-
masked_array = np.ma.array(series.values)
191-
masked_array = np.ma.masked_where(mask, masked_array)
192-
args = [series.index, masked_array]
193-
else:
194-
args = [series.index, series.values]
195-
return args
196-
197-
198174
def _get_freq(ax, series):
199175
# get frequency from data
200176
freq = getattr(series.index, 'freq', None)

0 commit comments

Comments
 (0)