Skip to content

CLN: Simplify LinePlot flow #7717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 66 additions & 114 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,9 +755,9 @@ class MPLPlot(object):
_default_rot = 0

_pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
'mark_right']
'mark_right', 'stacked']
_attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
'mark_right': True}
'mark_right': True, 'stacked': False}

def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
sharey=False, use_index=True,
Expand Down Expand Up @@ -1080,7 +1080,6 @@ def _make_legend(self):
for ax in self.axes:
ax.legend(loc='best')


def _get_ax_legend(self, ax):
leg = ax.get_legend()
other_ax = (getattr(ax, 'right_ax', None) or
Expand Down Expand Up @@ -1139,12 +1138,22 @@ def _get_plot_function(self):
Returns the matplotlib plotting function (plot or errorbar) based on
the presence of errorbar keywords.
'''

if all(e is None for e in self.errors.values()):
plotf = self.plt.Axes.plot
else:
plotf = self.plt.Axes.errorbar

errorbar = any(e is not None for e in self.errors.values())
def plotf(ax, x, y, style=None, **kwds):
mask = com.isnull(y)
if mask.any():
y = np.ma.array(y)
y = np.ma.masked_where(mask, y)

if errorbar:
return self.plt.Axes.errorbar(ax, x, y, **kwds)
else:
# prevent style kwarg from going to errorbar, where it is unsupported
if style is not None:
args = (ax, x, y, style)
else:
args = (ax, x, y)
return self.plt.Axes.plot(*args, **kwds)
return plotf

def _get_index_name(self):
Expand Down Expand Up @@ -1472,11 +1481,9 @@ def _post_plot_logic(self):
class LinePlot(MPLPlot):

def __init__(self, data, **kwargs):
self.stacked = kwargs.pop('stacked', False)
if self.stacked:
data = data.fillna(value=0)

MPLPlot.__init__(self, data, **kwargs)
if self.stacked:
self.data = self.data.fillna(value=0)
self.x_compat = plot_params['x_compat']
if 'x_compat' in self.kwds:
self.x_compat = bool(self.kwds.pop('x_compat'))
Expand Down Expand Up @@ -1533,56 +1540,39 @@ def _is_ts_plot(self):
return not self.x_compat and self.use_index and self._use_dynamic_x()

def _make_plot(self):
self._pos_prior = np.zeros(len(self.data))
self._neg_prior = np.zeros(len(self.data))
self._initialize_prior(len(self.data))

if self._is_ts_plot():
data = self._maybe_convert_index(self.data)
self._make_ts_plot(data)
x = data.index # dummy, not used
plotf = self._get_ts_plot_function()
it = self._iter_data(data=data, keep_index=True)
else:
x = self._get_xticks(convert_period=True)

plotf = self._get_plot_function()
colors = self._get_colors()

for i, (label, y) in enumerate(self._iter_data()):
ax = self._get_ax(i)
style = self._get_style(i, label)
kwds = self.kwds.copy()
self._maybe_add_color(colors, kwds, style, i)
it = self._iter_data()

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)

label = com.pprint_thing(label) # .encode('utf-8')
kwds['label'] = label

y_values = self._get_stacked_values(y, label)

if not self.stacked:
mask = com.isnull(y_values)
if mask.any():
y_values = np.ma.array(y_values)
y_values = np.ma.masked_where(mask, y_values)
colors = self._get_colors()
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
style = self._get_style(i, label)
kwds = self.kwds.copy()
self._maybe_add_color(colors, kwds, style, i)

# prevent style kwarg from going to errorbar, where it is unsupported
if style is not None and plotf.__name__ != 'errorbar':
args = (ax, x, y_values, style)
else:
args = (ax, x, y_values)
errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)

newlines = plotf(*args, **kwds)
self._add_legend_handle(newlines[0], label, index=i)
label = com.pprint_thing(label) # .encode('utf-8')
kwds['label'] = label
y_values = self._get_stacked_values(y, label)

if self.stacked and not self.subplots:
if (y >= 0).all():
self._pos_prior += y
elif (y <= 0).all():
self._neg_prior += y
newlines = plotf(ax, x, y_values, style=style, **kwds)
self._update_prior(y)
self._add_legend_handle(newlines[0], label, index=i)

lines = _get_all_lines(ax)
left, right = _get_xlim(lines)
ax.set_xlim(left, right)
lines = _get_all_lines(ax)
left, right = _get_xlim(lines)
ax.set_xlim(left, right)

def _get_stacked_values(self, y, label):
if self.stacked:
Expand All @@ -1599,46 +1589,26 @@ def _get_stacked_values(self, y, label):
def _get_ts_plot_function(self):
from pandas.tseries.plotting import tsplot
plotf = self._get_plot_function()

def _plot(data, ax, label, style, **kwds):
# errorbar function does not support style argument
if plotf.__name__ == 'errorbar':
lines = tsplot(data, plotf, ax=ax, label=label,
**kwds)
return lines
else:
lines = tsplot(data, plotf, ax=ax, label=label,
style=style, **kwds)
return lines
def _plot(ax, x, data, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
lines = tsplot(data, plotf, ax=ax, style=style, **kwds)
return lines
return _plot

def _make_ts_plot(self, data, **kwargs):
colors = self._get_colors()
plotf = self._get_ts_plot_function()

it = self._iter_data(data=data, keep_index=True)
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
style = self._get_style(i, label)
kwds = self.kwds.copy()

self._maybe_add_color(colors, kwds, style, i)

errors = self._get_errorbars(label=label, index=i, xerr=False)
kwds = dict(kwds, **errors)

label = com.pprint_thing(label)

y_values = self._get_stacked_values(y, label)

newlines = plotf(y_values, ax, label, style, **kwds)
self._add_legend_handle(newlines[0], label, index=i)
def _initialize_prior(self, n):
self._pos_prior = np.zeros(n)
self._neg_prior = np.zeros(n)

if self.stacked and not self.subplots:
if (y >= 0).all():
self._pos_prior += y
elif (y <= 0).all():
self._neg_prior += y
def _update_prior(self, y):
if self.stacked and not self.subplots:
# tsplot resample may changedata length
if len(self._pos_prior) != len(y):
self._initialize_prior(len(y))
if (y >= 0).all():
self._pos_prior += y
elif (y <= 0).all():
self._neg_prior += y

def _maybe_convert_index(self, data):
# tsplot converts automatically, but don't want to convert index
Expand Down Expand Up @@ -1707,30 +1677,25 @@ def _get_plot_function(self):
if self.logy or self.loglog:
raise ValueError("Log-y scales are not supported in area plot")
else:
f = LinePlot._get_plot_function(self)

def plotf(*args, **kwds):
lines = f(*args, **kwds)
f = MPLPlot._get_plot_function(self)
def plotf(ax, x, y, style=None, **kwds):
lines = f(ax, x, y, style=style, **kwds)

# get data from the line
# insert fill_between starting point
y = args[2]
xdata, y_values = lines[0].get_data(orig=False)

if (y >= 0).all():
start = self._pos_prior
elif (y <= 0).all():
start = self._neg_prior
else:
start = np.zeros(len(y))

# get x data from the line
# to retrieve x coodinates of tsplot
xdata = lines[0].get_data()[0]
# remove style
args = (args[0], xdata, start, y)

if not 'color' in kwds:
kwds['color'] = lines[0].get_color()

self.plt.Axes.fill_between(*args, **kwds)
self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
return lines

return plotf
Expand All @@ -1746,15 +1711,6 @@ def _add_legend_handle(self, handle, label, index=None):
def _post_plot_logic(self):
LinePlot._post_plot_logic(self)

if self._is_ts_plot():
pass
else:
if self.xlim is None:
for ax in self.axes:
lines = _get_all_lines(ax)
left, right = _get_xlim(lines)
ax.set_xlim(left, right)

if self.ylim is None:
if (self.data >= 0).all().all():
for ax in self.axes:
Expand All @@ -1769,12 +1725,8 @@ class BarPlot(MPLPlot):
_default_rot = {'bar': 90, 'barh': 0}

def __init__(self, data, **kwargs):
self.stacked = kwargs.pop('stacked', False)

self.bar_width = kwargs.pop('width', 0.5)

pos = kwargs.pop('position', 0.5)

kwargs.setdefault('align', 'center')
self.tick_pos = np.arange(len(data))

Expand Down
28 changes: 2 additions & 26 deletions pandas/tseries/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from pandas.tseries.converter import (PeriodConverter, TimeSeries_DateLocator,
TimeSeries_DateFormatter)

from pandas.tools.plotting import _get_all_lines, _get_xlim

#----------------------------------------------------------------------
# Plotting functions and monkey patches

Expand Down Expand Up @@ -59,25 +57,15 @@ def tsplot(series, plotf, **kwargs):
# Set ax with freq info
_decorate_axes(ax, freq, kwargs)

# mask missing values
args = _maybe_mask(series)

# how to make sure ax.clear() flows through?
if not hasattr(ax, '_plot_data'):
ax._plot_data = []
ax._plot_data.append((series, kwargs))

# styles
style = kwargs.pop('style', None)
if style is not None:
args.append(style)

lines = plotf(ax, *args, **kwargs)
lines = plotf(ax, series.index, series.values, **kwargs)

# set date formatter, locators and rescale limits
format_dateaxis(ax, ax.freq)
left, right = _get_xlim(_get_all_lines(ax))
ax.set_xlim(left, right)

# x and y coord info
ax.format_coord = lambda t, y: ("t = {0} "
Expand Down Expand Up @@ -165,8 +153,7 @@ def _replot_ax(ax, freq, plotf, kwargs):
idx = series.index.asfreq(freq, how='S')
series.index = idx
ax._plot_data.append(series)
args = _maybe_mask(series)
lines.append(plotf(ax, *args, **kwds)[0])
lines.append(plotf(ax, series.index, series.values, **kwds)[0])
labels.append(com.pprint_thing(series.name))

return lines, labels
Expand All @@ -184,17 +171,6 @@ def _decorate_axes(ax, freq, kwargs):
ax.date_axis_info = None


def _maybe_mask(series):
mask = isnull(series)
if mask.any():
masked_array = np.ma.array(series.values)
masked_array = np.ma.masked_where(mask, masked_array)
args = [series.index, masked_array]
else:
args = [series.index, series.values]
return args


def _get_freq(ax, series):
# get frequency from data
freq = getattr(series.index, 'freq', None)
Expand Down