diff --git a/doc/source/user_guide/visualization.rst b/doc/source/user_guide/visualization.rst index c4ee8677a6b0d..7911c58b9867e 100644 --- a/doc/source/user_guide/visualization.rst +++ b/doc/source/user_guide/visualization.rst @@ -552,6 +552,9 @@ These can be specified by the ``x`` and ``y`` keywords. .. ipython:: python df = pd.DataFrame(np.random.rand(50, 4), columns=["a", "b", "c", "d"]) + df["species"] = pd.Categorical( + ["setosa"] * 20 + ["versicolor"] * 20 + ["virginica"] * 10 + ) @savefig scatter_plot.png df.plot.scatter(x="a", y="b"); @@ -579,6 +582,21 @@ each point: df.plot.scatter(x="a", y="b", c="c", s=50); +.. ipython:: python + :suppress: + + plt.close("all") + +If a categorical column is passed to ``c``, then a discrete colorbar will be produced: + +.. versionadded:: 1.3.0 + +.. ipython:: python + + @savefig scatter_plot_categorical.png + df.plot.scatter(x="a", y="b", c="species", cmap="viridis", s=50); + + .. ipython:: python :suppress: diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 381a05a18b278..886e034f60d27 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -52,6 +52,7 @@ Other enhancements - :meth:`DataFrame.apply` can now accept NumPy unary operators as strings, e.g. ``df.apply("sqrt")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`Series.apply` can now accept list-like or dictionary-like arguments that aren't lists or dictionaries, e.g. ``ser.apply(np.array(["sum", "mean"]))``, which was already the case for :meth:`DataFrame.apply` (:issue:`39140`) +- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`12380`, :issue:`31357`) - :meth:`.Styler.set_tooltips` allows on hover tooltips to be added to styled HTML dataframes. .. --------------------------------------------------------------------------- diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index fa9f030ac4bb3..7d743075674f1 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -10,6 +10,7 @@ from pandas.util._decorators import cache_readonly from pandas.core.dtypes.common import ( + is_categorical_dtype, is_extension_array_dtype, is_float, is_float_dtype, @@ -388,6 +389,10 @@ def result(self): return self.axes[0] def _convert_to_ndarray(self, data): + # GH31357: categorical columns are processed separately + if is_categorical_dtype(data): + return data + # GH32073: cast to float if values contain nulled integers if ( is_integer_dtype(data.dtype) or is_float_dtype(data.dtype) @@ -974,7 +979,7 @@ def _plot_colorbar(self, ax: Axes, **kwds): if mpl_ge_3_0_0(): # The workaround below is no longer necessary. - return + return cbar points = ax.get_position().get_points() cbar_points = cbar.ax.get_position().get_points() @@ -992,6 +997,8 @@ def _plot_colorbar(self, ax: Axes, **kwds): # print(points[1, 1] - points[0, 1]) # print(cbar_points[1, 1] - cbar_points[0, 1]) + return cbar + class ScatterPlot(PlanePlot): _kind = "scatter" @@ -1014,6 +1021,8 @@ def _make_plot(self): c_is_column = is_hashable(c) and c in self.data.columns + color_by_categorical = c_is_column and is_categorical_dtype(self.data[c]) + # pandas uses colormap, matplotlib uses cmap. cmap = self.colormap or "Greys" cmap = self.plt.cm.get_cmap(cmap) @@ -1024,11 +1033,22 @@ def _make_plot(self): c_values = self.plt.rcParams["patch.facecolor"] elif color is not None: c_values = color + elif color_by_categorical: + c_values = self.data[c].cat.codes elif c_is_column: c_values = self.data[c].values else: c_values = c + if color_by_categorical: + from matplotlib import colors + + n_cats = len(self.data[c].cat.categories) + cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)]) + bounds = np.linspace(0, n_cats, n_cats + 1) + norm = colors.BoundaryNorm(bounds, cmap.N) + else: + norm = None # plot colorbar if # 1. colormap is assigned, and # 2.`c` is a column containing only numeric values @@ -1045,11 +1065,15 @@ def _make_plot(self): c=c_values, label=label, cmap=cmap, + norm=norm, **self.kwds, ) if cb: cbar_label = c if c_is_column else "" - self._plot_colorbar(ax, label=cbar_label) + cbar = self._plot_colorbar(ax, label=cbar_label) + if color_by_categorical: + cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) + cbar.ax.set_yticklabels(self.data[c].cat.categories) if label is not None: self._add_legend_handle(scatter, label) diff --git a/pandas/tests/plotting/frame/test_frame.py b/pandas/tests/plotting/frame/test_frame.py index d25741a0a9fae..41df9fb2e5af0 100644 --- a/pandas/tests/plotting/frame/test_frame.py +++ b/pandas/tests/plotting/frame/test_frame.py @@ -696,6 +696,37 @@ def test_scatterplot_object_data(self): _check_plot_works(df.plot.scatter, x="a", y="b") _check_plot_works(df.plot.scatter, x=0, y=1) + @pytest.mark.parametrize("ordered", [True, False]) + @pytest.mark.parametrize( + "categories", + (["setosa", "versicolor", "virginica"], ["versicolor", "virginica", "setosa"]), + ) + def test_scatterplot_color_by_categorical(self, ordered, categories): + df = DataFrame( + [[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]], + columns=["length", "width"], + ) + df["species"] = pd.Categorical( + ["setosa", "setosa", "virginica", "virginica", "versicolor"], + ordered=ordered, + categories=categories, + ) + ax = df.plot.scatter(x=0, y=1, c="species") + (colorbar_collection,) = ax.collections + colorbar = colorbar_collection.colorbar + + expected_ticks = np.array([0.5, 1.5, 2.5]) + result_ticks = colorbar.get_ticks() + tm.assert_numpy_array_equal(result_ticks, expected_ticks) + + expected_boundaries = np.array([0.0, 1.0, 2.0, 3.0]) + result_boundaries = colorbar._boundaries + tm.assert_numpy_array_equal(result_boundaries, expected_boundaries) + + expected_yticklabels = categories + result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()] + assert all(i == j for i, j in zip(result_yticklabels, expected_yticklabels)) + @pytest.mark.parametrize("x, y", [("x", "y"), ("y", "x"), ("y", "y")]) def test_plot_scatter_with_categorical_data(self, x, y): # after fixing GH 18755, should be able to plot categorical data