Skip to content

Commit 1cad74e

Browse files
authored
Fix scatter norm keyword (#45966)
* TST: add test for norm scatter plot parameter (#45809) * BUG: don't duplicate norm parameter for scatter plots (#45809) * TST: Add github issue numbers (GH45809) * TST: remove dependence on private attributes (#45809) * DOC: add entry to visualization bug fixes (#45809) * TST: reduce number of norm comparisons (#45809) * TST: simplify test (#45809)
1 parent 221c3aa commit 1cad74e

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

doc/source/whatsnew/v1.5.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ Plotting
421421
- Bug in :meth:`DataFrame.plot.box` that prevented labeling the x-axis (:issue:`45463`)
422422
- Bug in :meth:`DataFrame.boxplot` that prevented passing in ``xlabel`` and ``ylabel`` (:issue:`45463`)
423423
- Bug in :meth:`DataFrame.boxplot` that prevented specifying ``vert=False`` (:issue:`36918`)
424-
-
424+
- Bug in :meth:`DataFrame.scatter` that prevented specifying ``norm`` (:issue:`45809`)
425425

426426
Groupby/resample/rolling
427427
^^^^^^^^^^^^^^^^^^^^^^^^

pandas/plotting/_matplotlib/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,7 @@ def _make_plot(self):
10821082
bounds = np.linspace(0, n_cats, n_cats + 1)
10831083
norm = colors.BoundaryNorm(bounds, cmap.N)
10841084
else:
1085-
norm = None
1085+
norm = self.kwds.pop("norm", None)
10861086
# plot colorbar if
10871087
# 1. colormap is assigned, and
10881088
# 2.`c` is a column containing only numeric values

pandas/tests/plotting/frame/test_frame.py

+21
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,27 @@ def test_plot_scatter_with_s(self):
787787
ax = df.plot.scatter(x="a", y="b", s="c")
788788
tm.assert_numpy_array_equal(df["c"].values, right=ax.collections[0].get_sizes())
789789

790+
def test_plot_scatter_with_norm(self):
791+
# added while fixing GH 45809
792+
import matplotlib as mpl
793+
794+
df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"])
795+
norm = mpl.colors.LogNorm()
796+
ax = df.plot.scatter(x="a", y="b", c="c", norm=norm)
797+
assert ax.collections[0].norm is norm
798+
799+
def test_plot_scatter_without_norm(self):
800+
# added while fixing GH 45809
801+
import matplotlib as mpl
802+
803+
df = DataFrame(np.random.random((10, 3)) * 100, columns=["a", "b", "c"])
804+
ax = df.plot.scatter(x="a", y="b", c="c")
805+
plot_norm = ax.collections[0].norm
806+
color_min_max = (df.c.min(), df.c.max())
807+
default_norm = mpl.colors.Normalize(*color_min_max)
808+
for value in df.c:
809+
assert plot_norm(value) == default_norm(value)
810+
790811
@pytest.mark.slow
791812
def test_plot_bar(self):
792813
df = DataFrame(

0 commit comments

Comments
 (0)