Skip to content

ENH: Scatterplot Method added #3473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ API Changes
- Default export for ``to_clipboard`` is now csv with a sep of `\t` for
compat (:issue:`3368`)
- ``at`` now will enlarge the object inplace (and return the same) (:issue:`2578`)
- new class added to allow scatterplotting using ``df.plot(kind="scatter")`` (:issue:`2215`)
- ``HDFStore``

Expand Down
17 changes: 16 additions & 1 deletion pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def test_plot_xy(self):

# columns.inferred_type == 'mixed'
# TODO add MultiIndex test

@slow
def test_xcompat(self):
import pandas as pd
Expand Down Expand Up @@ -534,6 +534,21 @@ def test_subplots(self):
[self.assert_(label.get_visible())
for label in ax.get_yticklabels()]

@slow
def test_plot_scatter(self):
    # Exercise kind='scatter' with columns chosen by label and by integer
    # position, then check that omitting either x or y raises ValueError.
    df = DataFrame(randn(6, 4),
                   index=list(string.ascii_letters[:6]),
                   columns=['x', 'y', 'z', 'four'])

    _check_plot_works(df.plot, x='x', y='y', kind='scatter')
    _check_plot_works(df.plot, x=1, y=2, kind='scatter')

    # A scatter plot needs both axes; a single column is not enough.
    with tm.assertRaises(ValueError):
        df.plot(x='x', kind='scatter')
    with tm.assertRaises(ValueError):
        df.plot(y='y', kind='scatter')

@slow
def test_plot_bar(self):
from matplotlib.pylab import close
Expand Down
101 changes: 70 additions & 31 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def _gcf():
import matplotlib.pyplot as plt
return plt.gcf()


def _get_marker_compat(marker):
import matplotlib.lines as mlines
import matplotlib as mpl
Expand Down Expand Up @@ -1201,7 +1200,32 @@ def _post_plot_logic(self):
for ax in self.axes:
ax.legend(loc='best')


class ScatterPlot(MPLPlot):
    """Scatter plot of one column of ``data`` against another.

    Parameters
    ----------
    data : DataFrame
        Object providing the plotted columns.
    x, y : label or int
        Columns drawn on the horizontal and vertical axes.  An integer is
        treated as a position into ``data.columns`` unless the columns
        themselves hold integers, in which case it is used as a label.
    **kwargs
        Forwarded unchanged to ``MPLPlot.__init__``.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is None.
    """

    def __init__(self, data, x, y, **kwargs):
        # Reject incomplete requests before any base-class setup runs.
        if x is None or y is None:
            raise ValueError('scatter requires an x and y column')
        MPLPlot.__init__(self, data, **kwargs)
        # NOTE(review): self.kwds, self.plt and self.data are presumably
        # populated by MPLPlot.__init__ -- confirm against the base class.
        # Fall back to matplotlib's patch face colour when the caller gave
        # no explicit 'c'.
        self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
        self.x = self._resolve_column(x)
        self.y = self._resolve_column(y)

    def _resolve_column(self, key):
        # Translate an integer position into a column label, unless the
        # columns hold integers and the key is therefore already a label.
        columns = self.data.columns
        if com.is_integer(key) and not columns.holds_integer():
            return columns[key]
        return key

    def _make_plot(self):
        # Draw the points on the first axes, passing the accumulated
        # keyword arguments straight through to matplotlib.
        x, y, data = self.x, self.y, self.data
        ax = self.axes[0]
        ax.scatter(data[x].values, data[y].values, **self.kwds)

    def _post_plot_logic(self):
        # Label both axes with the names of the plotted columns.
        ax = self.axes[0]
        x, y = self.x, self.y
        ax.set_ylabel(com.pprint_thing(y))
        ax.set_xlabel(com.pprint_thing(x))


class LinePlot(MPLPlot):

def __init__(self, data, **kwargs):
Expand Down Expand Up @@ -1562,7 +1586,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
secondary_y=False, **kwds):

"""
Make line or bar plot of DataFrame's series with the index on the x-axis
Make line, bar, or scatter plots of DataFrame series with the index on the x-axis
using matplotlib / pylab.
Parameters
Expand Down Expand Up @@ -1593,10 +1617,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
ax : matplotlib axis object, default None
style : list or dict
matplotlib line style per column
kind : {'line', 'bar', 'barh', 'kde', 'density'}
kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'}
bar : vertical bar plot
barh : horizontal bar plot
kde/density : Kernel Density Estimation plot
scatter: scatter plot
logx : boolean, default False
For line plots, use log scaling on x axis
logy : boolean, default False
Expand Down Expand Up @@ -1632,36 +1657,50 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
elif kind == 'scatter':
klass = ScatterPlot
else:
raise ValueError('Invalid chart type given %s' % kind)

if x is not None:
if com.is_integer(x) and not frame.columns.holds_integer():
x = frame.columns[x]
frame = frame.set_index(x)

if y is not None:
if com.is_integer(y) and not frame.columns.holds_integer():
y = frame.columns[y]
label = x if x is not None else frame.index.name
label = kwds.pop('label', label)
ser = frame[y]
ser.index.name = label
return plot_series(ser, label=label, kind=kind,
use_index=use_index,
rot=rot, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, ax=ax, style=style,
grid=grid, logx=logx, logy=logy,
secondary_y=secondary_y, title=title,
figsize=figsize, fontsize=fontsize, **kwds)

plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)
if kind == 'scatter':
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
rot=rot,legend=legend, ax=ax, style=style,
fontsize=fontsize, use_index=use_index, sharex=sharex,
sharey=sharey, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, title=title, grid=grid,
figsize=figsize, logx=logx, logy=logy,
sort_columns=sort_columns, secondary_y=secondary_y,
**kwds)
else:
if x is not None:
if com.is_integer(x) and not frame.columns.holds_integer():
x = frame.columns[x]
frame = frame.set_index(x)

if y is not None:
if com.is_integer(y) and not frame.columns.holds_integer():
y = frame.columns[y]
label = x if x is not None else frame.index.name
label = kwds.pop('label', label)
ser = frame[y]
ser.index.name = label
return plot_series(ser, label=label, kind=kind,
use_index=use_index,
rot=rot, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, ax=ax, style=style,
grid=grid, logx=logx, logy=logy,
secondary_y=secondary_y, title=title,
figsize=figsize, fontsize=fontsize, **kwds)

else:
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)

plot_obj.generate()
plot_obj.draw()
if subplots:
Expand Down