Skip to content

Commit bd68912

Browse files
committed
BUG: allow numpy.array as c values to scatterplot
Ensure that we can pass an np.array as 'c' straight through to matplotlib, this functionality was accidentally removed previously. Add tests. Closes #8852
1 parent cba2916 commit bd68912

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

doc/source/whatsnew/v0.15.2.txt

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ Bug Fixes
156156
and the last offset is not calculated from the start of the range (:issue:`8683`)
157157

158158

159+
160+
- Bug where DataFrame.plot(kind='scatter') fails when checking if an np.array is in the DataFrame (:issue:`8852`)
161+
162+
163+
159164
- Bug in `pd.infer_freq`/`DataFrame.inferred_freq` that prevented proper sub-daily frequency inference
160165
when the index contained DST days (:issue:`8772`).
161166
- Bug where index name was still used when plotting a series with ``use_index=False`` (:issue:`8558`).

pandas/tests/test_graphics.py

+25
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,31 @@ def test_plot_scatter_with_c(self):
16451645
self.assertIs(ax.collections[0].colorbar, None)
16461646
self._check_colors(ax.collections, facecolors=['r'])
16471647

1648+
# Ensure that we can pass an np.array straight through to matplotlib,
1649+
# this functionality was accidentally removed previously.
1650+
# See https://github.com/pydata/pandas/issues/8852 for bug report
1651+
#
1652+
# Exercise colormap path and non-colormap path as they are independent
1653+
#
1654+
df = DataFrame({'A': [1, 2], 'B': [3, 4]})
1655+
red_rgba = [1.0, 0.0, 0.0, 1.0]
1656+
green_rgba = [0.0, 1.0, 0.0, 1.0]
1657+
rgba_array = np.array([red_rgba, green_rgba])
1658+
ax = df.plot(kind='scatter', x='A', y='B', c=rgba_array)
1659+
# expect the face colors of the points in the non-colormap path to be
1660+
# identical to the values we supplied, normally we'd be on shaky ground
1661+
# comparing floats for equality but here we expect them to be
1662+
# identical.
1663+
self.assertTrue(
1664+
np.array_equal(
1665+
ax.collections[0].get_facecolor(),
1666+
rgba_array))
1667+
# we don't test the colors of the faces in this next plot because they
1668+
# are dependent on the spring colormap, which may change its colors
1669+
# later.
1670+
float_array = np.array([0.0, 1.0])
1671+
df.plot(kind='scatter', x='A', y='B', c=float_array, cmap='spring')
1672+
16481673
@slow
16491674
def test_plot_bar(self):
16501675
df = DataFrame(randn(6, 4),

pandas/tools/plotting.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1403,16 +1403,18 @@ def _make_plot(self):
14031403
x, y, c, data = self.x, self.y, self.c, self.data
14041404
ax = self.axes[0]
14051405

1406+
c_is_column = com.is_hashable(c) and c in self.data.columns
1407+
14061408
# plot a colorbar only if a colormap is provided or necessary
1407-
cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns)
1409+
cb = self.kwds.pop('colorbar', self.colormap or c_is_column)
14081410

14091411
# pandas uses colormap, matplotlib uses cmap.
14101412
cmap = self.colormap or 'Greys'
14111413
cmap = plt.cm.get_cmap(cmap)
14121414

14131415
if c is None:
14141416
c_values = self.plt.rcParams['patch.facecolor']
1415-
elif c in self.data.columns:
1417+
elif c_is_column:
14161418
c_values = self.data[c].values
14171419
else:
14181420
c_values = c
@@ -1427,7 +1429,7 @@ def _make_plot(self):
14271429
img = ax.collections[0]
14281430
kws = dict(ax=ax)
14291431
if mpl_ge_1_3_1:
1430-
kws['label'] = c if c in self.data.columns else ''
1432+
kws['label'] = c if c_is_column else ''
14311433
self.fig.colorbar(img, **kws)
14321434

14331435
self._add_legend_handle(scatter, label)

0 commit comments

Comments
 (0)