Skip to content

CLN: simplify series plotting #6876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 27, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ def test_plot(self):
_check_plot_works(self.ts.plot, style='.', loglog=True)
_check_plot_works(self.ts[:10].plot, kind='bar')
_check_plot_works(self.iseries.plot)
_check_plot_works(self.series[:5].plot, kind='bar')
_check_plot_works(self.series[:5].plot, kind='line')
_check_plot_works(self.series[:5].plot, kind='barh')

for kind in plotting._common_kinds:
_check_plot_works(self.series[:5].plot, kind=kind)

_check_plot_works(self.series[:10].plot, kind='barh')
_check_plot_works(Series(randn(10)).plot, kind='bar', color='black')

Expand Down Expand Up @@ -250,25 +251,19 @@ def test_bootstrap_plot(self):

def test_invalid_plot_data(self):
s = Series(list('abcd'))
kinds = 'line', 'bar', 'barh', 'kde', 'density'

for kind in kinds:
for kind in plotting._common_kinds:
with tm.assertRaises(TypeError):
s.plot(kind=kind)

@slow
def test_valid_object_plot(self):
s = Series(lrange(10), dtype=object)
kinds = 'line', 'bar', 'barh', 'kde', 'density'

for kind in kinds:
for kind in plotting._common_kinds:
_check_plot_works(s.plot, kind=kind)

def test_partially_invalid_plot_data(self):
s = Series(['a', 'b', 1.0, 2])
kinds = 'line', 'bar', 'barh', 'kde', 'density'

for kind in kinds:
for kind in plotting._common_kinds:
with tm.assertRaises(TypeError):
s.plot(kind=kind)

Expand Down Expand Up @@ -1247,19 +1242,17 @@ def test_unordered_ts(self):
assert_array_equal(ydata, np.array([1.0, 2.0, 3.0]))

def test_all_invalid_plot_data(self):
kinds = 'line', 'bar', 'barh', 'kde', 'density'
df = DataFrame(list('abcd'))
for kind in kinds:
for kind in plotting._common_kinds:
with tm.assertRaises(TypeError):
df.plot(kind=kind)

@slow
def test_partially_invalid_plot_data(self):
with tm.RNGContext(42):
kinds = 'line', 'bar', 'barh', 'kde', 'density'
df = DataFrame(randn(10, 2), dtype=object)
df[np.random.rand(df.shape[0]) > 0.5] = 'a'
for kind in kinds:
for kind in plotting._common_kinds:
with tm.assertRaises(TypeError):
df.plot(kind=kind)

Expand Down
114 changes: 48 additions & 66 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,25 +887,31 @@ def _validate_color_args(self):
" use one or the other or pass 'style' "
"without a color symbol")

def _iter_data(self):
from pandas.core.frame import DataFrame
if isinstance(self.data, (Series, np.ndarray)):
yield self.label, np.asarray(self.data)
elif isinstance(self.data, DataFrame):
df = self.data
def _iter_data(self, data=None, keep_index=False):
if data is None:
data = self.data

from pandas.core.frame import DataFrame
if isinstance(data, (Series, np.ndarray)):
if keep_index is True:
yield self.label, data
else:
yield self.label, np.asarray(data)
elif isinstance(data, DataFrame):
if self.sort_columns:
columns = com._try_sort(df.columns)
columns = com._try_sort(data.columns)
else:
columns = df.columns
columns = data.columns

for col in columns:
# # is this right?
# empty = df[col].count() == 0
# values = df[col].values if not empty else np.zeros(len(df))

values = df[col].values
yield col, values
if keep_index is True:
yield col, data[col]
else:
yield col, data[col].values

@property
def nseries(self):
Expand Down Expand Up @@ -1593,38 +1599,26 @@ def _plot(data, col_num, ax, label, style, **kwds):

self._add_legend_handle(newlines[0], label, index=col_num)

if isinstance(data, Series):
ax = self._get_ax(0) # self.axes[0]
style = self.style or ''
label = com.pprint_thing(self.label)
it = self._iter_data(data=data, keep_index=True)
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
style = self._get_style(i, label)
kwds = self.kwds.copy()
self._maybe_add_color(colors, kwds, style, 0)

if 'yerr' in kwds:
kwds['yerr'] = kwds['yerr'][0]

_plot(data, 0, ax, label, self.style, **kwds)

else:
for i, col in enumerate(data.columns):
label = com.pprint_thing(col)
ax = self._get_ax(i)
style = self._get_style(i, col)
kwds = self.kwds.copy()

self._maybe_add_color(colors, kwds, style, i)
self._maybe_add_color(colors, kwds, style, i)

# key-matched DataFrame of errors
if 'yerr' in kwds:
yerr = kwds['yerr']
if isinstance(yerr, (DataFrame, dict)):
if col in yerr.keys():
kwds['yerr'] = yerr[col]
else: del kwds['yerr']
else:
kwds['yerr'] = yerr[i]
# key-matched DataFrame of errors
if 'yerr' in kwds:
yerr = kwds['yerr']
if isinstance(yerr, (DataFrame, dict)):
if label in yerr.keys():
kwds['yerr'] = yerr[label]
else: del kwds['yerr']
else:
kwds['yerr'] = yerr[i]

_plot(data[col], i, ax, label, style, **kwds)
label = com.pprint_thing(label)
_plot(y, i, ax, label, style, **kwds)

def _maybe_convert_index(self, data):
# tsplot converts automatically, but don't want to convert index
Expand Down Expand Up @@ -1828,6 +1822,16 @@ class BoxPlot(MPLPlot):
class HistPlot(MPLPlot):
pass

# Plot kinds accepted by both the DataFrame and the Series plotting entry
# points (plot_frame checks this list together with _dataframe_kinds;
# plot_series checks only this list).
_common_kinds = ['line', 'bar', 'barh', 'kde', 'density']
# Plot kinds accepted only by the DataFrame entry point (plot_frame).
_dataframe_kinds = ['scatter', 'hexbin']
# Every kind name known to the plotting functions.
_all_kinds = _common_kinds + _dataframe_kinds

# Lookup table from a kind name to the plot class that implements it;
# 'bar' and 'barh' share BarPlot.
# NOTE(review): 'density' is listed in _common_kinds but has no key here --
# presumably _get_standard_kind() rewrites it to 'kde' before the lookup;
# confirm, otherwise _plot_klass[kind] raises KeyError for 'density'.
_plot_klass = {'line': LinePlot, 'bar': BarPlot, 'barh': BarPlot,
'kde': KdePlot,
'scatter': ScatterPlot, 'hexbin': HexBinPlot}


def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
sharey=False, use_index=True, figsize=None, grid=None,
Expand Down Expand Up @@ -1921,21 +1925,14 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
is a function of one argument that reduces all the values in a bin to
a single number (e.g. `mean`, `max`, `sum`, `std`).
"""

kind = _get_standard_kind(kind.lower().strip())
if kind == 'line':
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
elif kind == 'scatter':
klass = ScatterPlot
elif kind == 'hexbin':
klass = HexBinPlot
if kind in _dataframe_kinds or kind in _common_kinds:
klass = _plot_klass[kind]
else:
raise ValueError('Invalid chart type given %s' % kind)

if kind == 'scatter':
if kind in _dataframe_kinds:
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
rot=rot,legend=legend, ax=ax, style=style,
fontsize=fontsize, use_index=use_index, sharex=sharex,
Expand All @@ -1944,16 +1941,6 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
figsize=figsize, logx=logx, logy=logy,
sort_columns=sort_columns, secondary_y=secondary_y,
**kwds)
elif kind == 'hexbin':
C = kwds.pop('C', None) # remove from kwargs so we can set default
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
rot=rot,legend=legend, ax=ax, style=style,
fontsize=fontsize, use_index=use_index, sharex=sharex,
sharey=sharey, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, title=title, grid=grid,
figsize=figsize, logx=logx, logy=logy,
sort_columns=sort_columns, secondary_y=secondary_y,
C=C, **kwds)
else:
if x is not None:
if com.is_integer(x) and not frame.columns.holds_integer():
Expand Down Expand Up @@ -2051,14 +2038,9 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
See matplotlib documentation online for more on this subject
"""

from pandas import DataFrame
kind = _get_standard_kind(kind.lower().strip())
if kind == 'line':
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
if kind in _common_kinds:
klass = _plot_klass[kind]
else:
raise ValueError('Invalid chart type given %s' % kind)

Expand Down