Skip to content

Commit f3e1991

Browse files
ENH: DataFrame.plot.scatter argument c now accepts a column of strings, where rows with the same string are colored identically (#59239)
1 parent 57a4fb9 commit f3e1991

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Other enhancements
5353
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
5454
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
5555
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
56+
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
5657
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
5758
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
5859
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)

pandas/plotting/_matplotlib/core.py

+41
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,22 @@ def _make_plot(self, fig: Figure) -> None:
13431343
label = self.label
13441344
else:
13451345
label = None
1346+
1347+
# if a list of non color strings is passed in as c, color points
1348+
# by uniqueness of the strings, such that identical strings get the same color
1349+
create_colors = not self._are_valid_colors(c_values)
1350+
if create_colors:
1351+
color_mapping = self._get_color_mapping(c_values)
1352+
c_values = [color_mapping[s] for s in c_values]
1353+
1354+
# build legend for labeling custom colors
1355+
ax.legend(
1356+
handles=[
1357+
mpl.patches.Circle((0, 0), facecolor=c, label=s)
1358+
for s, c in color_mapping.items()
1359+
]
1360+
)
1361+
13461362
scatter = ax.scatter(
13471363
data[x].values,
13481364
data[y].values,
@@ -1353,6 +1369,7 @@ def _make_plot(self, fig: Figure) -> None:
13531369
s=self.s,
13541370
**self.kwds,
13551371
)
1372+
13561373
if cb:
13571374
cbar_label = c if c_is_column else ""
13581375
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
@@ -1392,6 +1409,30 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
13921409
c_values = c
13931410
return c_values
13941411

1412+
def _are_valid_colors(self, c_values: Series) -> bool:
1413+
# check if c_values contains strings and if these strings are valid mpl colors.
1414+
# no need to check numerics as these (and mpl colors) will be validated for us
1415+
# in .Axes.scatter._parse_scatter_color_args(...)
1416+
unique = np.unique(c_values)
1417+
try:
1418+
if len(c_values) and all(isinstance(c, str) for c in unique):
1419+
mpl.colors.to_rgba_array(unique)
1420+
1421+
return True
1422+
1423+
except (TypeError, ValueError) as _:
1424+
return False
1425+
1426+
def _get_color_mapping(self, c_values: Series) -> dict[str, np.ndarray]:
    """
    Assign one colormap color to each distinct value in ``c_values``.

    Parameters
    ----------
    c_values : Series
        Labels to color by; equal labels share a color.

    Returns
    -------
    dict[str, np.ndarray]
        Mapping from each distinct label, in ``np.unique`` order, to the
        color sampled for it from the colormap.
    """
    unique = np.unique(c_values)

    # passing `None` here will default to :rc:`image.cmap`
    cmap = mpl.colormaps.get_cmap(self.colormap)

    # One evenly spaced sample of the colormap per distinct label; calling
    # the colormap with an array yields one RGBA row per sample.
    colors = cmap(np.linspace(0, 1, len(unique)))

    return dict(zip(unique, colors))
1435+
13951436
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
13961437
c = self.c
13971438
if self.colormap is not None:

pandas/tests/plotting/frame/test_frame_color.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,53 @@ def test_scatter_with_c_column_name_with_colors(self, cmap):
217217
ax = df.plot.scatter(x=0, y=1, cmap=cmap, c="species")
218218
else:
219219
ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap)
220+
221+
assert len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 3 # r/g/b
222+
assert (
223+
np.unique(ax.collections[0].get_facecolor(), axis=0)
224+
== np.array(
225+
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
226+
) # r/g/b
227+
).all()
220228
assert ax.collections[0].colorbar is None
221229

230+
def test_scatter_with_c_column_name_without_colors(self):
    # A string column passed as ``c`` whose values are not all matplotlib
    # colors must yield one distinct face color per distinct string.
    label_sets = [
        # no label is a matplotlib color
        ["NY", "MD", "MA", "CA"],
        # since not all are mpl-colors, points matching 'r' or 'g'
        # are not necessarily red or green
        ["r", "g", "not-a-color"],
    ]

    for labels in label_sets:
        # Cycle through the labels so every one of them appears in the data.
        df = DataFrame(
            {
                "dataX": range(100),
                "dataY": range(100),
                "color": [labels[i % len(labels)] for i in range(100)],
            }
        )

        ax = df.plot.scatter("dataX", "dataY", c="color")

        facecolors = ax.collections[0].get_facecolor()
        assert len(np.unique(facecolors, axis=0)) == len(labels)
266+
222267
def test_scatter_colors(self):
223268
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
224269
with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"):
@@ -229,7 +274,14 @@ def test_scatter_colors_not_raising_warnings(self):
229274
# provided via 'c'. Parameters 'cmap' will be ignored
230275
df = DataFrame({"x": [1, 2, 3], "y": [1, 2, 3]})
231276
with tm.assert_produces_warning(None):
232-
df.plot.scatter(x="x", y="y", c="b")
277+
ax = df.plot.scatter(x="x", y="y", c="b")
278+
assert (
279+
len(np.unique(ax.collections[0].get_facecolor(), axis=0)) == 1
280+
) # blue
281+
assert (
282+
np.unique(ax.collections[0].get_facecolor(), axis=0)
283+
== np.array([[0.0, 0.0, 1.0, 1.0]])
284+
).all() # blue
233285

234286
def test_scatter_colors_default(self):
235287
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})

0 commit comments

Comments
 (0)