diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 20415bba99476..aeb0eeab0b3ba 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -452,6 +452,7 @@ Other - Fixed bug in :func:`pandas.testing.assert_series_equal` where dtypes were checked for ``Interval`` and ``ExtensionArray`` operands when ``check_dtype`` was ``False`` (:issue:`32747`) - Bug in :meth:`Series.map` not raising on invalid ``na_action`` (:issue:`32815`) - Bug in :meth:`DataFrame.__dir__` caused a segfault when using unicode surrogates in a column name (:issue:`25509`) +- Bug in :meth:`DataFrame.plot.scatter` caused an error when plotting variable marker sizes (:issue:`32904`) .. --------------------------------------------------------------------------- diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index d3db539084609..e466a215091ea 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -1468,15 +1468,19 @@ def scatter(self, x, y, s=None, c=None, **kwargs): y : int or str The column name or column position to be used as vertical coordinates for each point. - s : scalar or array_like, optional + s : str, scalar or array_like, optional The size of each point. Possible values are: + - A string with the name of the column to be used for marker's size. + - A single scalar so all points have the same size. - A sequence of scalars, which will be used for each point's size recursively. For instance, when passing [2,14] all points size will be either 2 or 14, alternatively. + .. versionchanged:: 1.1.0 + c : str, int or array_like, optional The color of each point. Possible values are: diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 63d0b8abe59d9..bc8346fd48433 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -934,6 +934,8 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs): # hide the matplotlib default for size, in case we want to change # the handling of this argument later s = 20 + elif is_hashable(s) and s in data.columns: + s = data[s] super().__init__(data, x, y, s=s, **kwargs) if is_integer(c) and not self.data.columns.holds_integer(): c = self.data.columns[c] diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index 45ac18b2661c3..08b33ee547a48 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -1306,6 +1306,13 @@ def test_plot_scatter_with_c(self): float_array = np.array([0.0, 1.0]) df.plot.scatter(x="A", y="B", c=float_array, cmap="spring") + def test_plot_scatter_with_s(self): + # this refers to GH 32904 + df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"],) + + ax = df.plot.scatter(x="a", y="b", s="c") + tm.assert_numpy_array_equal(df["c"].values, right=ax.collections[0].get_sizes()) + def test_scatter_colors(self): df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) with pytest.raises(TypeError):