Skip to content

Commit 42ed2a2

Browse files
author
Tom Augspurger
committed
Merge pull request #7780 from shoyer/plot-frame-scatter
API: support `c` and `colormap` args for DataFrame.plot with kind='scatter'
2 parents 1a2885f + 67e380a commit 42ed2a2

File tree

4 files changed

+71
-5
lines changed

4 files changed

+71
-5
lines changed

doc/source/v0.15.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,8 @@ Enhancements
491491

492492
- Added ``layout`` keyword to ``DataFrame.plot`` (:issue:`6667`)
493493
- Allow passing multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`)
494+
- Added support for ``c``, ``colormap`` and ``colorbar`` arguments for
495+
``DataFrame.plot`` with ``kind='scatter'`` (:issue:`7780`)
494496

495497

496498
- ``PeriodIndex`` supports ``resolution`` in the same way as ``DatetimeIndex`` (:issue:`7708`)

doc/source/visualization.rst

+8
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,14 @@ It is recommended to specify ``color`` and ``label`` keywords to distinguish eac
521521
df.plot(kind='scatter', x='c', y='d',
522522
color='DarkGreen', label='Group 2', ax=ax);
523523
524+
The keyword ``c`` may be given as the name of a column to provide colors for
525+
each point:
526+
527+
.. ipython:: python
528+
529+
@savefig scatter_plot_colored.png
530+
df.plot(kind='scatter', x='a', y='b', c='c', s=50);
531+
524532
You can pass other keywords supported by matplotlib ``scatter``.
525533
The example below shows a bubble chart that uses a dataframe column's values as the bubble sizes.
526534

pandas/tests/test_graphics.py

+28
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,34 @@ def test_plot_scatter(self):
15141514
axes = df.plot(x='x', y='y', kind='scatter', subplots=True)
15151515
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
15161516

1517+
@slow
1518+
def test_plot_scatter_with_c(self):
1519+
df = DataFrame(randn(6, 4),
1520+
index=list(string.ascii_letters[:6]),
1521+
columns=['x', 'y', 'z', 'four'])
1522+
1523+
axes = [df.plot(kind='scatter', x='x', y='y', c='z'),
1524+
df.plot(kind='scatter', x=0, y=1, c=2)]
1525+
for ax in axes:
1526+
# default to RdBu
1527+
self.assertEqual(ax.collections[0].cmap.name, 'RdBu')
1528+
# n.b. there appears to be no public method to get the colorbar
1529+
# label
1530+
self.assertEqual(ax.collections[0].colorbar._label, 'z')
1531+
1532+
cm = 'cubehelix'
1533+
ax = df.plot(kind='scatter', x='x', y='y', c='z', colormap=cm)
1534+
self.assertEqual(ax.collections[0].cmap.name, cm)
1535+
1536+
# verify turning off colorbar works
1537+
ax = df.plot(kind='scatter', x='x', y='y', c='z', colorbar=False)
1538+
self.assertIs(ax.collections[0].colorbar, None)
1539+
1540+
# verify that we can still plot a solid color
1541+
ax = df.plot(x=0, y=1, c='red', kind='scatter')
1542+
self.assertIs(ax.collections[0].colorbar, None)
1543+
self._check_colors(ax.collections, facecolors=['r'])
1544+
15171545
@slow
15181546
def test_plot_bar(self):
15191547
df = DataFrame(randn(6, 4),

pandas/tools/plotting.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -1370,32 +1370,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
13701370
class ScatterPlot(MPLPlot):
13711371
_layout_type = 'single'
13721372

1373-
def __init__(self, data, x, y, **kwargs):
1373+
def __init__(self, data, x, y, c=None, **kwargs):
13741374
MPLPlot.__init__(self, data, **kwargs)
1375-
self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
13761375
if x is None or y is None:
13771376
raise ValueError( 'scatter requires and x and y column')
13781377
if com.is_integer(x) and not self.data.columns.holds_integer():
13791378
x = self.data.columns[x]
13801379
if com.is_integer(y) and not self.data.columns.holds_integer():
13811380
y = self.data.columns[y]
1381+
if com.is_integer(c) and not self.data.columns.holds_integer():
1382+
c = self.data.columns[c]
13821383
self.x = x
13831384
self.y = y
1385+
self.c = c
13841386

13851387
@property
13861388
def nseries(self):
13871389
return 1
13881390

13891391
def _make_plot(self):
1390-
x, y, data = self.x, self.y, self.data
1392+
import matplotlib.pyplot as plt
1393+
1394+
x, y, c, data = self.x, self.y, self.c, self.data
13911395
ax = self.axes[0]
13921396

1397+
# plot a colorbar only if a colormap is provided or necessary
1398+
cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns)
1399+
1400+
# pandas uses colormap, matplotlib uses cmap.
1401+
cmap = self.colormap or 'RdBu'
1402+
cmap = plt.cm.get_cmap(cmap)
1403+
1404+
if c is None:
1405+
c_values = self.plt.rcParams['patch.facecolor']
1406+
elif c in self.data.columns:
1407+
c_values = self.data[c].values
1408+
else:
1409+
c_values = c
1410+
13931411
if self.legend and hasattr(self, 'label'):
13941412
label = self.label
13951413
else:
13961414
label = None
1397-
scatter = ax.scatter(data[x].values, data[y].values, label=label,
1398-
**self.kwds)
1415+
scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
1416+
label=label, cmap=cmap, **self.kwds)
1417+
if cb:
1418+
img = ax.collections[0]
1419+
cb_label = c if c in self.data.columns else ''
1420+
self.fig.colorbar(img, ax=ax, label=cb_label)
1421+
13991422
self._add_legend_handle(scatter, label)
14001423

14011424
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
@@ -2261,6 +2284,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22612284
colormap : str or matplotlib colormap object, default None
22622285
Colormap to select colors from. If string, load colormap with that name
22632286
from matplotlib.
2287+
colorbar : boolean, optional
2288+
If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
22642289
position : float
22652290
Specify relative alignments for bar plot layout.
22662291
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
@@ -2287,6 +2312,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22872312
`C` specifies the value at each `(x, y)` point and `reduce_C_function`
22882313
is a function of one argument that reduces all the values in a bin to
22892314
a single number (e.g. `mean`, `max`, `sum`, `std`).
2315+
2316+
If `kind`='scatter' and the argument `c` is the name of a dataframe column,
2317+
the values of that column are used to color each point.
22902318
"""
22912319

22922320
kind = _get_standard_kind(kind.lower().strip())

0 commit comments

Comments
 (0)