Skip to content

ENH/VIZ: Allowing s parameter of scatter plots to be a column name #33107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ Other
- Fixed bug in :func:`pandas.testing.assert_series_equal` where dtypes were checked for ``Interval`` and ``ExtensionArray`` operands when ``check_dtype`` was ``False`` (:issue:`32747`)
- Bug in :meth:`Series.map` not raising on invalid ``na_action`` (:issue:`32815`)
- Bug in :meth:`DataFrame.__dir__` caused a segfault when using unicode surrogates in a column name (:issue:`25509`)
- Bug in :meth:`DataFrame.plot.scatter` caused an error when plotting variable marker sizes (:issue:`32904`)

.. ---------------------------------------------------------------------------

Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,15 +1468,19 @@ def scatter(self, x, y, s=None, c=None, **kwargs):
y : int or str
The column name or column position to be used as vertical
coordinates for each point.
s : scalar or array_like, optional
s : str, scalar or array_like, optional
The size of each point. Possible values are:

- A string with the name of the column to be used for marker's size.

- A single scalar so all points have the same size.

- A sequence of scalars, which will be used for each point's size
recursively. For instance, when passing [2,14] all points size
will be either 2 or 14, alternatively.

.. versionchanged:: 1.1.0

c : str, int or array_like, optional
The color of each point. Possible values are:

Expand Down
2 changes: 2 additions & 0 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,8 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs):
# hide the matplotlib default for size, in case we want to change
# the handling of this argument later
s = 20
elif is_hashable(s) and s in data.columns:
s = data[s]
super().__init__(data, x, y, s=s, **kwargs)
if is_integer(c) and not self.data.columns.holds_integer():
c = self.data.columns[c]
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/plotting/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,13 @@ def test_plot_scatter_with_c(self):
float_array = np.array([0.0, 1.0])
df.plot.scatter(x="A", y="B", c=float_array, cmap="spring")

def test_plot_scatter_with_s(self):
# this refers to GH 32904
df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"],)

ax = df.plot.scatter(x="a", y="b", s="c")
tm.assert_numpy_array_equal(df["c"].values, right=ax.collections[0].get_sizes())

def test_scatter_colors(self):
df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
with pytest.raises(TypeError):
Expand Down