Skip to content

Commit 1357114

Browse files
ENH/VIZ: Allowing s parameter of scatter plots to be a column name (#33107)
1 parent cbd1103 commit 1357114

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ Other
471471
- Fixed bug in :func:`pandas.testing.assert_series_equal` where dtypes were checked for ``Interval`` and ``ExtensionArray`` operands when ``check_dtype`` was ``False`` (:issue:`32747`)
472472
- Bug in :meth:`Series.map` not raising on invalid ``na_action`` (:issue:`32815`)
473473
- Bug in :meth:`DataFrame.__dir__` caused a segfault when using unicode surrogates in a column name (:issue:`25509`)
474+
- Bug in :meth:`DataFrame.plot.scatter` caused an error when plotting variable marker sizes (:issue:`32904`)
474475

475476
.. ---------------------------------------------------------------------------
476477

pandas/plotting/_core.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1468,15 +1468,19 @@ def scatter(self, x, y, s=None, c=None, **kwargs):
14681468
y : int or str
14691469
The column name or column position to be used as vertical
14701470
coordinates for each point.
1471-
s : scalar or array_like, optional
1471+
s : str, scalar or array_like, optional
14721472
The size of each point. Possible values are:
14731473
1474+
- A string with the name of the column to be used for the marker sizes.
1475+
14741476
- A single scalar so all points have the same size.
14751477
14761478
- A sequence of scalars, which will be used for each point's size
14771479
cyclically. For instance, when passing [2,14] all points size
14781480
will be either 2 or 14, alternately.
14791481
1482+
.. versionchanged:: 1.1.0
1483+
14801484
c : str, int or array_like, optional
14811485
The color of each point. Possible values are:
14821486

pandas/plotting/_matplotlib/core.py

+2
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,8 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs):
934934
# hide the matplotlib default for size, in case we want to change
935935
# the handling of this argument later
936936
s = 20
937+
elif is_hashable(s) and s in data.columns:
938+
s = data[s]
937939
super().__init__(data, x, y, s=s, **kwargs)
938940
if is_integer(c) and not self.data.columns.holds_integer():
939941
c = self.data.columns[c]

pandas/tests/plotting/test_frame.py

+7
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,13 @@ def test_plot_scatter_with_c(self):
13061306
float_array = np.array([0.0, 1.0])
13071307
df.plot.scatter(x="A", y="B", c=float_array, cmap="spring")
13081308

1309+
def test_plot_scatter_with_s(self):
1310+
# this refers to GH 32904
1311+
df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"],)
1312+
1313+
ax = df.plot.scatter(x="a", y="b", s="c")
1314+
tm.assert_numpy_array_equal(df["c"].values, right=ax.collections[0].get_sizes())
1315+
13091316
def test_scatter_colors(self):
13101317
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
13111318
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)