@@ -322,7 +322,6 @@ def _gcf():
322
322
import matplotlib .pyplot as plt
323
323
return plt .gcf ()
324
324
325
-
326
325
def _get_marker_compat (marker ):
327
326
import matplotlib .lines as mlines
328
327
import matplotlib as mpl
@@ -1201,7 +1200,32 @@ def _post_plot_logic(self):
1201
1200
for ax in self .axes :
1202
1201
ax .legend (loc = 'best' )
1203
1202
1204
-
1203
+ class ScatterPlot (MPLPlot ):
1204
+ def __init__ (self , data , x , y , ** kwargs ):
1205
+ MPLPlot .__init__ (self , data , ** kwargs )
1206
+ self .kwds .setdefault ('c' , self .plt .rcParams ['patch.facecolor' ])
1207
+ if x is None or y is None :
1208
+ raise ValueError ( 'scatter requires and x and y column' )
1209
+ if com .is_integer (x ) and not self .data .columns .holds_integer ():
1210
+ x = self .data .columns [x ]
1211
+ if com .is_integer (y ) and not self .data .columns .holds_integer ():
1212
+ y = self .data .columns [y ]
1213
+ self .x = x
1214
+ self .y = y
1215
+
1216
+
1217
+ def _make_plot (self ):
1218
+ x , y , data = self .x , self .y , self .data
1219
+ ax = self .axes [0 ]
1220
+ ax .scatter (data [x ].values , data [y ].values , ** self .kwds )
1221
+
1222
+ def _post_plot_logic (self ):
1223
+ ax = self .axes [0 ]
1224
+ x , y = self .x , self .y
1225
+ ax .set_ylabel (com .pprint_thing (y ))
1226
+ ax .set_xlabel (com .pprint_thing (x ))
1227
+
1228
+
1205
1229
class LinePlot (MPLPlot ):
1206
1230
1207
1231
def __init__ (self , data , ** kwargs ):
@@ -1562,7 +1586,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1562
1586
secondary_y = False , ** kwds ):
1563
1587
1564
1588
"""
1565
- Make line or bar plot of DataFrame's series with the index on the x-axis
1589
+ Make line, bar, or scatter plots of DataFrame series with the index on the x-axis
1566
1590
using matplotlib / pylab.
1567
1591
1568
1592
Parameters
@@ -1593,10 +1617,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1593
1617
ax : matplotlib axis object, default None
1594
1618
style : list or dict
1595
1619
matplotlib line style per column
1596
- kind : {'line', 'bar', 'barh', 'kde', 'density'}
1620
+ kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter' }
1597
1621
bar : vertical bar plot
1598
1622
barh : horizontal bar plot
1599
1623
kde/density : Kernel Density Estimation plot
1624
+ scatter: scatter plot
1600
1625
logx : boolean, default False
1601
1626
For line plots, use log scaling on x axis
1602
1627
logy : boolean, default False
@@ -1632,36 +1657,50 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1632
1657
klass = BarPlot
1633
1658
elif kind == 'kde' :
1634
1659
klass = KdePlot
1660
+ elif kind == 'scatter' :
1661
+ klass = ScatterPlot
1635
1662
else :
1636
1663
raise ValueError ('Invalid chart type given %s' % kind )
1637
1664
1638
- if x is not None :
1639
- if com .is_integer (x ) and not frame .columns .holds_integer ():
1640
- x = frame .columns [x ]
1641
- frame = frame .set_index (x )
1642
-
1643
- if y is not None :
1644
- if com .is_integer (y ) and not frame .columns .holds_integer ():
1645
- y = frame .columns [y ]
1646
- label = x if x is not None else frame .index .name
1647
- label = kwds .pop ('label' , label )
1648
- ser = frame [y ]
1649
- ser .index .name = label
1650
- return plot_series (ser , label = label , kind = kind ,
1651
- use_index = use_index ,
1652
- rot = rot , xticks = xticks , yticks = yticks ,
1653
- xlim = xlim , ylim = ylim , ax = ax , style = style ,
1654
- grid = grid , logx = logx , logy = logy ,
1655
- secondary_y = secondary_y , title = title ,
1656
- figsize = figsize , fontsize = fontsize , ** kwds )
1657
-
1658
- plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
1659
- legend = legend , ax = ax , style = style , fontsize = fontsize ,
1660
- use_index = use_index , sharex = sharex , sharey = sharey ,
1661
- xticks = xticks , yticks = yticks , xlim = xlim , ylim = ylim ,
1662
- title = title , grid = grid , figsize = figsize , logx = logx ,
1663
- logy = logy , sort_columns = sort_columns ,
1664
- secondary_y = secondary_y , ** kwds )
1665
+ if kind == 'scatter' :
1666
+ plot_obj = klass (frame , x = x , y = y , kind = kind , subplots = subplots ,
1667
+ rot = rot ,legend = legend , ax = ax , style = style ,
1668
+ fontsize = fontsize , use_index = use_index , sharex = sharex ,
1669
+ sharey = sharey , xticks = xticks , yticks = yticks ,
1670
+ xlim = xlim , ylim = ylim , title = title , grid = grid ,
1671
+ figsize = figsize , logx = logx , logy = logy ,
1672
+ sort_columns = sort_columns , secondary_y = secondary_y ,
1673
+ ** kwds )
1674
+ else :
1675
+ if x is not None :
1676
+ if com .is_integer (x ) and not frame .columns .holds_integer ():
1677
+ x = frame .columns [x ]
1678
+ frame = frame .set_index (x )
1679
+
1680
+ if y is not None :
1681
+ if com .is_integer (y ) and not frame .columns .holds_integer ():
1682
+ y = frame .columns [y ]
1683
+ label = x if x is not None else frame .index .name
1684
+ label = kwds .pop ('label' , label )
1685
+ ser = frame [y ]
1686
+ ser .index .name = label
1687
+ return plot_series (ser , label = label , kind = kind ,
1688
+ use_index = use_index ,
1689
+ rot = rot , xticks = xticks , yticks = yticks ,
1690
+ xlim = xlim , ylim = ylim , ax = ax , style = style ,
1691
+ grid = grid , logx = logx , logy = logy ,
1692
+ secondary_y = secondary_y , title = title ,
1693
+ figsize = figsize , fontsize = fontsize , ** kwds )
1694
+
1695
+ else :
1696
+ plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
1697
+ legend = legend , ax = ax , style = style , fontsize = fontsize ,
1698
+ use_index = use_index , sharex = sharex , sharey = sharey ,
1699
+ xticks = xticks , yticks = yticks , xlim = xlim , ylim = ylim ,
1700
+ title = title , grid = grid , figsize = figsize , logx = logx ,
1701
+ logy = logy , sort_columns = sort_columns ,
1702
+ secondary_y = secondary_y , ** kwds )
1703
+
1665
1704
plot_obj .generate ()
1666
1705
plot_obj .draw ()
1667
1706
if subplots :
0 commit comments