diff --git a/doc/source/release.rst b/doc/source/release.rst index 34cc4e499a0d5..c85e86caa8114 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -210,6 +210,7 @@ API Changes - Default export for ``to_clipboard`` is now csv with a sep of `\t` for compat (:issue:`3368`) - ``at`` now will enlarge the object inplace (and return the same) (:issue:`2578`) + - new class added to allow scatterplotting using ``df.plot(kind="scatter")`` (:issue:`2215`) - ``HDFStore`` diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index bdeb4ca3d0212..be18f0bd5cf89 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -449,7 +449,7 @@ def test_plot_xy(self): # columns.inferred_type == 'mixed' # TODO add MultiIndex test - + @slow def test_xcompat(self): import pandas as pd @@ -534,6 +534,21 @@ def test_subplots(self): [self.assert_(label.get_visible()) for label in ax.get_yticklabels()] + @slow + def test_plot_scatter(self): + from matplotlib.pylab import close + df = DataFrame(randn(6, 4), + index=list(string.ascii_letters[:6]), + columns=['x', 'y', 'z', 'four']) + + _check_plot_works(df.plot, x='x', y='y', kind='scatter') + _check_plot_works(df.plot, x=1, y=2, kind='scatter') + + with tm.assertRaises(ValueError): + df.plot(x='x', kind='scatter') + with tm.assertRaises(ValueError): + df.plot(y='y', kind='scatter') + @slow def test_plot_bar(self): from matplotlib.pylab import close diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index d6c0482d86be4..7de5840384974 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -322,7 +322,6 @@ def _gcf(): import matplotlib.pyplot as plt return plt.gcf() - def _get_marker_compat(marker): import matplotlib.lines as mlines import matplotlib as mpl @@ -1201,7 +1200,32 @@ def _post_plot_logic(self): for ax in self.axes: ax.legend(loc='best') - +class ScatterPlot(MPLPlot): + def __init__(self, data, x, y, **kwargs): + MPLPlot.__init__(self, data,
**kwargs) + self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor']) + if x is None or y is None: + raise ValueError( 'scatter requires and x and y column') + if com.is_integer(x) and not self.data.columns.holds_integer(): + x = self.data.columns[x] + if com.is_integer(y) and not self.data.columns.holds_integer(): + y = self.data.columns[y] + self.x = x + self.y = y + + + def _make_plot(self): + x, y, data = self.x, self.y, self.data + ax = self.axes[0] + ax.scatter(data[x].values, data[y].values, **self.kwds) + + def _post_plot_logic(self): + ax = self.axes[0] + x, y = self.x, self.y + ax.set_ylabel(com.pprint_thing(y)) + ax.set_xlabel(com.pprint_thing(x)) + + class LinePlot(MPLPlot): def __init__(self, data, **kwargs): @@ -1562,7 +1586,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, secondary_y=False, **kwds): """ - Make line or bar plot of DataFrame's series with the index on the x-axis + Make line, bar, or scatter plots of DataFrame series with the index on the x-axis using matplotlib / pylab. 
Parameters @@ -1593,10 +1617,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, ax : matplotlib axis object, default None style : list or dict matplotlib line style per column - kind : {'line', 'bar', 'barh', 'kde', 'density'} + kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'} bar : vertical bar plot barh : horizontal bar plot kde/density : Kernel Density Estimation plot + scatter : scatter plot logx : boolean, default False For line plots, use log scaling on x axis logy : boolean, default False @@ -1632,36 +1657,50 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, klass = BarPlot elif kind == 'kde': klass = KdePlot + elif kind == 'scatter': + klass = ScatterPlot else: raise ValueError('Invalid chart type given %s' % kind) - if x is not None: - if com.is_integer(x) and not frame.columns.holds_integer(): - x = frame.columns[x] - frame = frame.set_index(x) - - if y is not None: - if com.is_integer(y) and not frame.columns.holds_integer(): - y = frame.columns[y] - label = x if x is not None else frame.index.name - label = kwds.pop('label', label) - ser = frame[y] - ser.index.name = label - return plot_series(ser, label=label, kind=kind, - use_index=use_index, - rot=rot, xticks=xticks, yticks=yticks, - xlim=xlim, ylim=ylim, ax=ax, style=style, - grid=grid, logx=logx, logy=logy, - secondary_y=secondary_y, title=title, - figsize=figsize, fontsize=fontsize, **kwds) - - plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, - legend=legend, ax=ax, style=style, fontsize=fontsize, - use_index=use_index, sharex=sharex, sharey=sharey, - xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, - title=title, grid=grid, figsize=figsize, logx=logx, - logy=logy, sort_columns=sort_columns, - secondary_y=secondary_y, **kwds) + if kind == 'scatter': + plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, + rot=rot,legend=legend, ax=ax, style=style, + fontsize=fontsize, use_index=use_index, sharex=sharex,
sharey=sharey, xticks=xticks, yticks=yticks, + xlim=xlim, ylim=ylim, title=title, grid=grid, + figsize=figsize, logx=logx, logy=logy, + sort_columns=sort_columns, secondary_y=secondary_y, + **kwds) + else: + if x is not None: + if com.is_integer(x) and not frame.columns.holds_integer(): + x = frame.columns[x] + frame = frame.set_index(x) + + if y is not None: + if com.is_integer(y) and not frame.columns.holds_integer(): + y = frame.columns[y] + label = x if x is not None else frame.index.name + label = kwds.pop('label', label) + ser = frame[y] + ser.index.name = label + return plot_series(ser, label=label, kind=kind, + use_index=use_index, + rot=rot, xticks=xticks, yticks=yticks, + xlim=xlim, ylim=ylim, ax=ax, style=style, + grid=grid, logx=logx, logy=logy, + secondary_y=secondary_y, title=title, + figsize=figsize, fontsize=fontsize, **kwds) + + else: + plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, + legend=legend, ax=ax, style=style, fontsize=fontsize, + use_index=use_index, sharex=sharex, sharey=sharey, + xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, + title=title, grid=grid, figsize=figsize, logx=logx, + logy=logy, sort_columns=sort_columns, + secondary_y=secondary_y, **kwds) + plot_obj.generate() plot_obj.draw() if subplots: