@@ -1370,32 +1370,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
1370
1370
class ScatterPlot (MPLPlot ):
1371
1371
_layout_type = 'single'
1372
1372
1373
- def __init__ (self , data , x , y , ** kwargs ):
1373
+ def __init__ (self , data , x , y , c = None , ** kwargs ):
1374
1374
MPLPlot .__init__ (self , data , ** kwargs )
1375
- self .kwds .setdefault ('c' , self .plt .rcParams ['patch.facecolor' ])
1376
1375
if x is None or y is None :
1377
1376
raise ValueError ( 'scatter requires and x and y column' )
1378
1377
if com .is_integer (x ) and not self .data .columns .holds_integer ():
1379
1378
x = self .data .columns [x ]
1380
1379
if com .is_integer (y ) and not self .data .columns .holds_integer ():
1381
1380
y = self .data .columns [y ]
1381
+ if com .is_integer (c ) and not self .data .columns .holds_integer ():
1382
+ c = self .data .columns [c ]
1382
1383
self .x = x
1383
1384
self .y = y
1385
+ self .c = c
1384
1386
1385
1387
@property
1386
1388
def nseries (self ):
1387
1389
return 1
1388
1390
1389
1391
def _make_plot (self ):
1390
- x , y , data = self .x , self .y , self .data
1392
+ import matplotlib .pyplot as plt
1393
+
1394
+ x , y , c , data = self .x , self .y , self .c , self .data
1391
1395
ax = self .axes [0 ]
1392
1396
1397
+ # plot a colorbar only if a colormap is provided or necessary
1398
+ cb = self .kwds .pop ('colorbar' , self .colormap or c in self .data .columns )
1399
+
1400
+ # pandas uses colormap, matplotlib uses cmap.
1401
+ cmap = self .colormap or 'RdBu'
1402
+ cmap = plt .cm .get_cmap (cmap )
1403
+
1404
+ if c is None :
1405
+ c_values = self .plt .rcParams ['patch.facecolor' ]
1406
+ elif c in self .data .columns :
1407
+ c_values = self .data [c ].values
1408
+ else :
1409
+ c_values = c
1410
+
1393
1411
if self .legend and hasattr (self , 'label' ):
1394
1412
label = self .label
1395
1413
else :
1396
1414
label = None
1397
- scatter = ax .scatter (data [x ].values , data [y ].values , label = label ,
1398
- ** self .kwds )
1415
+ scatter = ax .scatter (data [x ].values , data [y ].values , c = c_values ,
1416
+ label = label , cmap = cmap , ** self .kwds )
1417
+ if cb :
1418
+ img = ax .collections [0 ]
1419
+ cb_label = c if c in self .data .columns else ''
1420
+ self .fig .colorbar (img , ax = ax , label = cb_label )
1421
+
1399
1422
self ._add_legend_handle (scatter , label )
1400
1423
1401
1424
errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
@@ -2261,6 +2284,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
2261
2284
colormap : str or matplotlib colormap object, default None
2262
2285
Colormap to select colors from. If string, load colormap with that name
2263
2286
from matplotlib.
2287
+ colorbar : boolean, optional
2288
+ If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
2264
2289
position : float
2265
2290
Specify relative alignments for bar plot layout.
2266
2291
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
@@ -2287,6 +2312,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
2287
2312
`C` specifies the value at each `(x, y)` point and `reduce_C_function`
2288
2313
is a function of one argument that reduces all the values in a bin to
2289
2314
a single number (e.g. `mean`, `max`, `sum`, `std`).
2315
+
2316
+ If `kind`='scatter' and the argument `c` is the name of a dataframe column,
2317
+ the values of that column are used to color each point.
2290
2318
"""
2291
2319
2292
2320
kind = _get_standard_kind (kind .lower ().strip ())
0 commit comments