Skip to content

Commit dfd92c7

Browse files
author
zach powers
committed
add ScatterPlot class to allow df.plot(kind=scatter)
1 parent eba6001 commit dfd92c7

File tree

3 files changed

+87
-32
lines changed

3 files changed

+87
-32
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ API Changes
210210
- Default export for ``to_clipboard`` is now csv with a sep of `\t` for
211211
compat (:issue:`3368`)
212212
- ``at`` now will enlarge the object inplace (and return the same) (:issue:`2578`)
213+
- New ``ScatterPlot`` class added to allow scatter plotting using ``df.plot(kind="scatter")`` (:issue:`2215`)
213214
214215
- ``HDFStore``
215216

pandas/tests/test_graphics.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def test_plot_xy(self):
449449

450450
# columns.inferred_type == 'mixed'
451451
# TODO add MultiIndex test
452-
452+
453453
@slow
454454
def test_xcompat(self):
455455
import pandas as pd
@@ -534,6 +534,21 @@ def test_subplots(self):
534534
[self.assert_(label.get_visible())
535535
for label in ax.get_yticklabels()]
536536

537+
@slow
def test_plot_scatter(self):
    # NOTE(review): `close` is imported but never used in this test; the
    # import is kept as-is so the test's import-time behaviour is unchanged.
    from matplotlib.pylab import close

    # 6x4 frame of random values with letter row labels.
    row_labels = list(string.ascii_letters[:6])
    frame = DataFrame(randn(6, 4), index=row_labels,
                      columns=['x', 'y', 'z', 'four'])

    # x/y may be given either as column labels or as integer positions.
    for x_col, y_col in [('x', 'y'), (1, 2)]:
        _check_plot_works(frame.plot, x=x_col, y=y_col, kind='scatter')

    # Supplying only one of x / y must raise ValueError.
    for partial in [{'x': 'x'}, {'y': 'y'}]:
        with tm.assertRaises(ValueError):
            frame.plot(kind='scatter', **partial)
551+
537552
@slow
538553
def test_plot_bar(self):
539554
from matplotlib.pylab import close

pandas/tools/plotting.py

+70-31
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def _gcf():
322322
import matplotlib.pyplot as plt
323323
return plt.gcf()
324324

325-
326325
def _get_marker_compat(marker):
327326
import matplotlib.lines as mlines
328327
import matplotlib as mpl
@@ -1201,7 +1200,32 @@ def _post_plot_logic(self):
12011200
for ax in self.axes:
12021201
ax.legend(loc='best')
12031202

1204-
1203+
class ScatterPlot(MPLPlot):
    """
    Scatter plot of one column of ``data`` against another.

    Selected by ``plot_frame`` when ``kind='scatter'``.  Unlike the other
    plot kinds, both an ``x`` and a ``y`` column must be supplied.
    """

    def __init__(self, data, x, y, **kwargs):
        """
        Parameters
        ----------
        data : object
            Passed unchanged to ``MPLPlot.__init__``.
        x, y : column label or integer
            Columns to plot on the x and y axes.  An integer is treated as
            a position into ``self.data.columns`` unless the columns
            themselves hold integers, in which case it is used as a label.
        **kwargs
            Forwarded to ``MPLPlot.__init__``.

        Raises
        ------
        ValueError
            If either ``x`` or ``y`` is None.
        """
        MPLPlot.__init__(self, data, **kwargs)
        # Default the marker colour to matplotlib's patch face colour unless
        # the caller passed an explicit ``c``.
        self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
        if x is None or y is None:
            raise ValueError('scatter requires an x and y column')

        # Translate positional integers into column labels, but only when
        # the column index does not itself consist of integers.
        columns = self.data.columns
        if com.is_integer(x) and not columns.holds_integer():
            x = columns[x]
        if com.is_integer(y) and not columns.holds_integer():
            y = columns[y]
        self.x = x
        self.y = y

    def _make_plot(self):
        """Draw the scatter of column ``self.x`` vs ``self.y`` on the first axis."""
        x, y, data = self.x, self.y, self.data
        ax = self.axes[0]
        ax.scatter(data[x].values, data[y].values, **self.kwds)

    def _post_plot_logic(self):
        """Label the first axis's x and y axes with the plotted column names."""
        ax = self.axes[0]
        x, y = self.x, self.y
        ax.set_ylabel(com.pprint_thing(y))
        ax.set_xlabel(com.pprint_thing(x))
1228+
12051229
class LinePlot(MPLPlot):
12061230

12071231
def __init__(self, data, **kwargs):
@@ -1562,7 +1586,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
15621586
secondary_y=False, **kwds):
15631587

15641588
"""
1565-
Make line or bar plot of DataFrame's series with the index on the x-axis
1589+
Make line, bar, or scatter plots of a DataFrame's series (index on the x-axis, except for scatter plots, which plot column ``y`` against column ``x``)
15661590
using matplotlib / pylab.
15671591
15681592
Parameters
@@ -1593,10 +1617,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
15931617
ax : matplotlib axis object, default None
15941618
style : list or dict
15951619
matplotlib line style per column
1596-
kind : {'line', 'bar', 'barh', 'kde', 'density'}
1620+
kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'}
15971621
bar : vertical bar plot
15981622
barh : horizontal bar plot
15991623
kde/density : Kernel Density Estimation plot
1624+
scatter : scatter plot (requires both ``x`` and ``y``)
16001625
logx : boolean, default False
16011626
For line plots, use log scaling on x axis
16021627
logy : boolean, default False
@@ -1632,36 +1657,50 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
16321657
klass = BarPlot
16331658
elif kind == 'kde':
16341659
klass = KdePlot
1660+
elif kind == 'scatter':
1661+
klass = ScatterPlot
16351662
else:
16361663
raise ValueError('Invalid chart type given %s' % kind)
16371664

1638-
if x is not None:
1639-
if com.is_integer(x) and not frame.columns.holds_integer():
1640-
x = frame.columns[x]
1641-
frame = frame.set_index(x)
1642-
1643-
if y is not None:
1644-
if com.is_integer(y) and not frame.columns.holds_integer():
1645-
y = frame.columns[y]
1646-
label = x if x is not None else frame.index.name
1647-
label = kwds.pop('label', label)
1648-
ser = frame[y]
1649-
ser.index.name = label
1650-
return plot_series(ser, label=label, kind=kind,
1651-
use_index=use_index,
1652-
rot=rot, xticks=xticks, yticks=yticks,
1653-
xlim=xlim, ylim=ylim, ax=ax, style=style,
1654-
grid=grid, logx=logx, logy=logy,
1655-
secondary_y=secondary_y, title=title,
1656-
figsize=figsize, fontsize=fontsize, **kwds)
1657-
1658-
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
1659-
legend=legend, ax=ax, style=style, fontsize=fontsize,
1660-
use_index=use_index, sharex=sharex, sharey=sharey,
1661-
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
1662-
title=title, grid=grid, figsize=figsize, logx=logx,
1663-
logy=logy, sort_columns=sort_columns,
1664-
secondary_y=secondary_y, **kwds)
1665+
if kind == 'scatter':
1666+
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
1667+
rot=rot,legend=legend, ax=ax, style=style,
1668+
fontsize=fontsize, use_index=use_index, sharex=sharex,
1669+
sharey=sharey, xticks=xticks, yticks=yticks,
1670+
xlim=xlim, ylim=ylim, title=title, grid=grid,
1671+
figsize=figsize, logx=logx, logy=logy,
1672+
sort_columns=sort_columns, secondary_y=secondary_y,
1673+
**kwds)
1674+
else:
1675+
if x is not None:
1676+
if com.is_integer(x) and not frame.columns.holds_integer():
1677+
x = frame.columns[x]
1678+
frame = frame.set_index(x)
1679+
1680+
if y is not None:
1681+
if com.is_integer(y) and not frame.columns.holds_integer():
1682+
y = frame.columns[y]
1683+
label = x if x is not None else frame.index.name
1684+
label = kwds.pop('label', label)
1685+
ser = frame[y]
1686+
ser.index.name = label
1687+
return plot_series(ser, label=label, kind=kind,
1688+
use_index=use_index,
1689+
rot=rot, xticks=xticks, yticks=yticks,
1690+
xlim=xlim, ylim=ylim, ax=ax, style=style,
1691+
grid=grid, logx=logx, logy=logy,
1692+
secondary_y=secondary_y, title=title,
1693+
figsize=figsize, fontsize=fontsize, **kwds)
1694+
1695+
else:
1696+
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
1697+
legend=legend, ax=ax, style=style, fontsize=fontsize,
1698+
use_index=use_index, sharex=sharex, sharey=sharey,
1699+
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
1700+
title=title, grid=grid, figsize=figsize, logx=logx,
1701+
logy=logy, sort_columns=sort_columns,
1702+
secondary_y=secondary_y, **kwds)
1703+
16651704
plot_obj.generate()
16661705
plot_obj.draw()
16671706
if subplots:

0 commit comments

Comments
 (0)