diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 44429b11856eb..1c809b2b77a6d 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -421,7 +421,7 @@ Plotting - Bug in :meth:`DataFrame.plot.box` that prevented labeling the x-axis (:issue:`45463`) - Bug in :meth:`DataFrame.boxplot` that prevented passing in ``xlabel`` and ``ylabel`` (:issue:`45463`) - Bug in :meth:`DataFrame.boxplot` that prevented specifying ``vert=False`` (:issue:`36918`) -- +- Bug in :meth:`DataFrame.scatter` that prevented specifying ``norm`` (:issue:`45809`) Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 41bd1df81ef61..48875114794d9 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -1082,7 +1082,7 @@ def _make_plot(self): bounds = np.linspace(0, n_cats, n_cats + 1) norm = colors.BoundaryNorm(bounds, cmap.N) else: - norm = None + norm = self.kwds.pop("norm", None) # plot colorbar if # 1. colormap is assigned, and # 2.`c` is a column containing only numeric values diff --git a/pandas/tests/plotting/frame/test_frame.py b/pandas/tests/plotting/frame/test_frame.py index 818c86dfca424..62540a15f47bd 100644 --- a/pandas/tests/plotting/frame/test_frame.py +++ b/pandas/tests/plotting/frame/test_frame.py @@ -787,6 +787,27 @@ def test_plot_scatter_with_s(self): ax = df.plot.scatter(x="a", y="b", s="c") tm.assert_numpy_array_equal(df["c"].values, right=ax.collections[0].get_sizes()) + def test_plot_scatter_with_norm(self): + # added while fixing GH 45809 + import matplotlib as mpl + + df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"]) + norm = mpl.colors.LogNorm() + ax = df.plot.scatter(x="a", y="b", c="c", norm=norm) + assert ax.collections[0].norm is norm + + def test_plot_scatter_without_norm(self): + # added while fixing GH 45809 + import matplotlib as mpl + + df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"]) + ax = df.plot.scatter(x="a", y="b", c="c") + plot_norm = ax.collections[0].norm + color_min_max = (df.c.min(), df.c.max()) + default_norm = mpl.colors.Normalize(*color_min_max) + for value in df.c: + assert plot_norm(value) == default_norm(value) + @pytest.mark.slow def test_plot_bar(self): df = DataFrame(