Skip to content

Commit 67e380a

Browse files
committed
API: support c and colormap args for DataFrame.plot with kind='scatter'
`matplotlib.pyplot.scatter` supports the argument `c` for setting the color of each point. This patch lets you set it easily by giving a column name: `df.plot('x', 'y', c='z', kind='scatter')` instead of `df.plot('x', 'y', c=df['z'].values, kind='scatter')`. (Currently you need to supply an ndarray to make it work, since pandas passes `c` through to matplotlib without resolving it as a column name.) While I was at it, I noticed that `kind='scatter'` did not support the `colormap` argument that some of the other methods support (notably `kind='hexbin'`), so I added it, too. This change should be almost entirely backwards compatible, unless a data frame has columns named after valid matplotlib colors and the same color name is passed as the `c` argument. A colorbar is also added automatically when relevant, i.e. when `c` names a column or a colormap is given; pass `colorbar=False` to turn it off.
1 parent 7800290 commit 67e380a

File tree

4 files changed

+71
-5
lines changed

4 files changed

+71
-5
lines changed

doc/source/v0.15.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ Enhancements
435435

436436
- Added ``layout`` keyword to ``DataFrame.plot`` (:issue:`6667`)
437437
- Allow passing multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`)
438+
- Added support for ``c``, ``colormap`` and ``colorbar`` arguments for
439+
``DataFrame.plot`` with ``kind='scatter'`` (:issue:`7780`)
438440

439441

440442
- ``PeriodIndex`` supports ``resolution`` in the same way as ``DatetimeIndex`` (:issue:`7708`)

doc/source/visualization.rst

+8
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,14 @@ It is recommended to specify ``color`` and ``label`` keywords to distinguish eac
521521
df.plot(kind='scatter', x='c', y='d',
522522
color='DarkGreen', label='Group 2', ax=ax);
523523
524+
The keyword ``c`` may be given as the name of a column to provide colors for
525+
each point:
526+
527+
.. ipython:: python
528+
529+
@savefig scatter_plot_colored.png
530+
df.plot(kind='scatter', x='a', y='b', c='c', s=50);
531+
524532
You can pass other keywords supported by matplotlib ``scatter``.
525533
The example below shows a bubble chart using a dataframe column's values as the bubble size.
526534

pandas/tests/test_graphics.py

+28
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,34 @@ def test_plot_scatter(self):
14971497
axes = df.plot(x='x', y='y', kind='scatter', subplots=True)
14981498
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
14991499

1500+
@slow
1501+
def test_plot_scatter_with_c(self):
1502+
df = DataFrame(randn(6, 4),
1503+
index=list(string.ascii_letters[:6]),
1504+
columns=['x', 'y', 'z', 'four'])
1505+
1506+
axes = [df.plot(kind='scatter', x='x', y='y', c='z'),
1507+
df.plot(kind='scatter', x=0, y=1, c=2)]
1508+
for ax in axes:
1509+
# default to RdBu
1510+
self.assertEqual(ax.collections[0].cmap.name, 'RdBu')
1511+
# n.b. there appears to be no public method to get the colorbar
1512+
# label
1513+
self.assertEqual(ax.collections[0].colorbar._label, 'z')
1514+
1515+
cm = 'cubehelix'
1516+
ax = df.plot(kind='scatter', x='x', y='y', c='z', colormap=cm)
1517+
self.assertEqual(ax.collections[0].cmap.name, cm)
1518+
1519+
# verify turning off colorbar works
1520+
ax = df.plot(kind='scatter', x='x', y='y', c='z', colorbar=False)
1521+
self.assertIs(ax.collections[0].colorbar, None)
1522+
1523+
# verify that we can still plot a solid color
1524+
ax = df.plot(x=0, y=1, c='red', kind='scatter')
1525+
self.assertIs(ax.collections[0].colorbar, None)
1526+
self._check_colors(ax.collections, facecolors=['r'])
1527+
15001528
@slow
15011529
def test_plot_bar(self):
15021530
df = DataFrame(randn(6, 4),

pandas/tools/plotting.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -1368,32 +1368,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
13681368
class ScatterPlot(MPLPlot):
13691369
_layout_type = 'single'
13701370

1371-
def __init__(self, data, x, y, **kwargs):
1371+
def __init__(self, data, x, y, c=None, **kwargs):
13721372
MPLPlot.__init__(self, data, **kwargs)
1373-
self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
13741373
if x is None or y is None:
13751374
raise ValueError( 'scatter requires and x and y column')
13761375
if com.is_integer(x) and not self.data.columns.holds_integer():
13771376
x = self.data.columns[x]
13781377
if com.is_integer(y) and not self.data.columns.holds_integer():
13791378
y = self.data.columns[y]
1379+
if com.is_integer(c) and not self.data.columns.holds_integer():
1380+
c = self.data.columns[c]
13801381
self.x = x
13811382
self.y = y
1383+
self.c = c
13821384

13831385
@property
13841386
def nseries(self):
13851387
return 1
13861388

13871389
def _make_plot(self):
1388-
x, y, data = self.x, self.y, self.data
1390+
import matplotlib.pyplot as plt
1391+
1392+
x, y, c, data = self.x, self.y, self.c, self.data
13891393
ax = self.axes[0]
13901394

1395+
# plot a colorbar only if a colormap is provided or necessary
1396+
cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns)
1397+
1398+
# pandas uses colormap, matplotlib uses cmap.
1399+
cmap = self.colormap or 'RdBu'
1400+
cmap = plt.cm.get_cmap(cmap)
1401+
1402+
if c is None:
1403+
c_values = self.plt.rcParams['patch.facecolor']
1404+
elif c in self.data.columns:
1405+
c_values = self.data[c].values
1406+
else:
1407+
c_values = c
1408+
13911409
if self.legend and hasattr(self, 'label'):
13921410
label = self.label
13931411
else:
13941412
label = None
1395-
scatter = ax.scatter(data[x].values, data[y].values, label=label,
1396-
**self.kwds)
1413+
scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
1414+
label=label, cmap=cmap, **self.kwds)
1415+
if cb:
1416+
img = ax.collections[0]
1417+
cb_label = c if c in self.data.columns else ''
1418+
self.fig.colorbar(img, ax=ax, label=cb_label)
1419+
13971420
self._add_legend_handle(scatter, label)
13981421

13991422
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
@@ -2259,6 +2282,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22592282
colormap : str or matplotlib colormap object, default None
22602283
Colormap to select colors from. If string, load colormap with that name
22612284
from matplotlib.
2285+
colorbar : boolean, optional
2286+
If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
22622287
position : float
22632288
Specify relative alignments for bar plot layout.
22642289
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
@@ -2285,6 +2310,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22852310
`C` specifies the value at each `(x, y)` point and `reduce_C_function`
22862311
is a function of one argument that reduces all the values in a bin to
22872312
a single number (e.g. `mean`, `max`, `sum`, `std`).
2313+
2314+
If `kind` = 'scatter' and the argument `c` is the name of a dataframe column,
2315+
the values of that column are used to color each point.
22882316
"""
22892317

22902318
kind = _get_standard_kind(kind.lower().strip())

0 commit comments

Comments
 (0)