From d8d8c3c5619ca5b2ce2dc37cf0f0e081ab8414d0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:49:34 +0100 Subject: [PATCH 01/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index a049ac99f0e08..2d5c6ef1d089a 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -8,6 +8,7 @@ from pandas.util._decorators import cache_readonly from pandas.core.dtypes.common import ( + is_categorical_dtype, is_hashable, is_integer, is_iterator, @@ -413,7 +414,8 @@ def _compute_plot_data(self): # np.ndarray before plot. numeric_data = numeric_data.copy() for col in numeric_data: - numeric_data[col] = np.asarray(numeric_data[col]) + if not is_categorical_dtype(numeric_data[col]): + numeric_data[col] = np.asarray(numeric_data[col]) self.data = numeric_data @@ -965,7 +967,11 @@ def _make_plot(self): elif color is not None: c_values = color elif c_is_column: - c_values = self.data[c].values + if not is_categorical_dtype(self.data[c]): + c_values = self.data[c].values + + else: + c_values = self.data[c].cat.codes else: c_values = c @@ -983,7 +989,20 @@ def _make_plot(self): ) if cb: cbar_label = c if c_is_column else "" - self._plot_colorbar(ax, label=cbar_label) + if not is_categorical_dtype(self.data[c]): + self._plot_colorbar(ax, label=cbar_label) + else: + codes = (self.data[c].cat.codes).unique() + handles = [ + self.plt.scatter( + [], + [], + color=scatter.cmap(scatter.norm(i)), + label=self.data[c].cat.categories[i], + ) + for i in codes + ] + ax.legend(handles=handles, title=cbar_label) if label is not None: self._add_legend_handle(scatter, label) From a1673136d98312f91c44e6c615acb9fd038a1b1b Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:50:53 +0100 Subject: [PATCH 02/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 2d5c6ef1d089a..bfd72376dd0ba 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -993,13 +993,10 @@ def _make_plot(self): self._plot_colorbar(ax, label=cbar_label) else: codes = (self.data[c].cat.codes).unique() + colors = [scatter.cmap(scatter.norm(i)) for i in codes] + labels = [self.data[c].cat.categories[i] for i in codes] handles = [ - self.plt.scatter( - [], - [], - color=scatter.cmap(scatter.norm(i)), - label=self.data[c].cat.categories[i], - ) + self.plt.scatter([], [], color=colors[i], label=labels[i],) for i in codes ] ax.legend(handles=handles, title=cbar_label) From d10c5c12e51f7a61e343aab4ce46810750fbc47a Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:51:14 +0100 Subject: [PATCH 03/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index bfd72376dd0ba..8170063fcfc12 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -996,7 +996,7 @@ def _make_plot(self): colors = [scatter.cmap(scatter.norm(i)) for i in codes] labels = [self.data[c].cat.categories[i] for i in codes] handles = [ - self.plt.scatter([], [], color=colors[i], label=labels[i],) + self.plt.scatter([], [], color=colors[i], label=labels[i]) for i in codes ] ax.legend(handles=handles, title=cbar_label) From 7294aa2e90663efe21e110e5e39894cd7aec443f Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:53:12 +0100 Subject: [PATCH 04/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 8170063fcfc12..684e853137a12 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -993,12 +993,13 @@ def _make_plot(self): self._plot_colorbar(ax, label=cbar_label) else: codes = (self.data[c].cat.codes).unique() - colors = [scatter.cmap(scatter.norm(i)) for i in codes] - labels = [self.data[c].cat.categories[i] for i in codes] - handles = [ - self.plt.scatter([], [], color=colors[i], label=labels[i]) - for i in codes - ] + empty_plot = lambda i: self.plt.scatter( + [], + [], + color=scatter.cmap(scatter.norm(i)), + label=self.data[c].cat.categories[i], + ) + handles = [empty_plot(i) for i in codes] ax.legend(handles=handles, title=cbar_label) if label is not None: From 50cd05f65370c60e6c0ba23880eafa6e7ca613e7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:53:40 +0100 Subject: [PATCH 05/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 684e853137a12..5b6c875ca4d35 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -992,14 +992,13 @@ def _make_plot(self): if not is_categorical_dtype(self.data[c]): self._plot_colorbar(ax, label=cbar_label) else: - codes = (self.data[c].cat.codes).unique() empty_plot = lambda i: self.plt.scatter( [], [], color=scatter.cmap(scatter.norm(i)), label=self.data[c].cat.categories[i], ) - handles = [empty_plot(i) for i in codes] + handles = [empty_plot(i) for i in self.data[c].cat.codes.unique()] ax.legend(handles=handles, title=cbar_label) if label is not None: From 4579142391bde091db732342f2fc7cfd42263759 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 15:54:24 +0100 Subject: [PATCH 06/19] add legend with colors if coloring by categorical --- pandas/plotting/_matplotlib/core.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 5b6c875ca4d35..dba82c403c79c 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -992,13 +992,15 @@ def _make_plot(self): if not is_categorical_dtype(self.data[c]): self._plot_colorbar(ax, label=cbar_label) else: - empty_plot = lambda i: self.plt.scatter( - [], - [], - color=scatter.cmap(scatter.norm(i)), - label=self.data[c].cat.categories[i], - ) - handles = [empty_plot(i) for i in self.data[c].cat.codes.unique()] + handles = [ + self.plt.scatter( + [], + [], + color=scatter.cmap(scatter.norm(i)), + label=self.data[c].cat.categories[i], + ) + for i in self.data[c].cat.codes.unique() + ] ax.legend(handles=handles, title=cbar_label) if label is not None: From 6d3fe9e36c7bc5844f963a3c1bf1b522ec8786a8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 16:14:00 +0100 Subject: [PATCH 07/19] add test --- pandas/tests/plotting/test_frame.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index c84a09f21f46b..1b40413b58ef0 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -1190,6 +1190,16 @@ def test_scatterplot_object_data(self): _check_plot_works(df.plot.scatter, x="a", y="b") _check_plot_works(df.plot.scatter, x=0, y=1) + def test_scatterplot_color_by_categorical(self): + df = pd.DataFrame( + [[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]], + columns=["length", "width"], + ) + df["species"] = pd.Categorical( + ["setosa", "setosa", "virginica", "virginica", "versicolor"] + ) + _check_plot_works(df.plot.scatter, x=0, y=1, c="species") + @pytest.mark.slow def test_if_scatterplot_colorbar_affects_xaxis_visibility(self): # addressing issue #10611, to ensure colobar does not From 3846804120886849c6dcff6a2712304502f1223b Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Thu, 21 May 2020 16:16:47 +0100 Subject: [PATCH 08/19] revert empty line --- pandas/plotting/_matplotlib/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index dba82c403c79c..7f22188009dbc 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -969,7 +969,6 @@ def _make_plot(self): elif c_is_column: if not is_categorical_dtype(self.data[c]): c_values = self.data[c].values - else: c_values = self.data[c].cat.codes else: From ef4b03deabef2aa6f129b3e97e4094b31b3c21c2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 11:43:57 +0100 Subject: [PATCH 09/19] discrete colorbar in case of ordered categorical --- pandas/plotting/_matplotlib/core.py | 71 ++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 7f22188009dbc..dedc5226591e6 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -913,7 +913,7 @@ def _plot_colorbar(self, ax, **kwds): if _mpl_ge_3_0_0(): # The workaround below is no longer necessary. - return + return cbar points = ax.get_position().get_points() cbar_points = cbar.ax.get_position().get_points() @@ -931,6 +931,8 @@ def _plot_colorbar(self, ax, **kwds): # print(points[1, 1] - points[0, 1]) # print(cbar_points[1, 1] - cbar_points[0, 1]) + return cbar + class ScatterPlot(PlanePlot): _kind = "scatter" @@ -953,6 +955,12 @@ def _make_plot(self): c_is_column = is_hashable(c) and c in self.data.columns + color_by_categorical = c_is_column and is_categorical_dtype(self.data[c]) + color_by_ordered_categorical = color_by_categorical and self.data[c].cat.ordered + color_by_unordered_categorical = ( + color_by_categorical and not self.data[c].cat.ordered + ) + # plot a colorbar only if a colormap is provided or necessary cb = self.kwds.pop("colorbar", self.colormap or c_is_column) @@ -967,40 +975,71 @@ def _make_plot(self): elif color is not None: c_values = color elif c_is_column: - if not is_categorical_dtype(self.data[c]): + if not color_by_categorical: c_values = self.data[c].values else: c_values = self.data[c].cat.codes + else: c_values = c + if color_by_categorical: + from matplotlib import colors + + n_cats = len(self.data[c].cat.categories) + cmap = colors.ListedColormap([cmap(i) for i in range((cmap.N))]) + bounds = np.linspace(0, n_cats, n_cats + 1) + ticks = np.linspace(0.5, n_cats - 0.5, n_cats) + norm = colors.BoundaryNorm(bounds, cmap.N) + if self.legend and hasattr(self, "label"): label = self.label else: label = None - scatter = ax.scatter( - data[x].values, - data[y].values, - c=c_values, - label=label, - cmap=cmap, - **self.kwds, - ) + if color_by_categorical: + scatter = ax.scatter( + data[x].values, + data[y].values, + c=c_values, + label=label, + cmap=cmap, + norm=norm, + **self.kwds, + ) + else: + scatter = ax.scatter( + data[x].values, + data[y].values, + c=c_values, + label=label, + cmap=cmap, + **self.kwds, + ) if cb: cbar_label = c if c_is_column else "" - if not is_categorical_dtype(self.data[c]): - self._plot_colorbar(ax, label=cbar_label) - else: + if color_by_unordered_categorical: handles = [ self.plt.scatter( - [], - [], + *([], []), color=scatter.cmap(scatter.norm(i)), label=self.data[c].cat.categories[i], + norm=norm, ) - for i in self.data[c].cat.codes.unique() + for i in range(n_cats) ] ax.legend(handles=handles, title=cbar_label) + elif color_by_ordered_categorical: + cbar = self._plot_colorbar( + ax, + label=cbar_label, + cmap=cmap, + boundaries=bounds, + ticks=ticks, + norm=norm, + ) + cbar.ax.set_yticklabels(self.data[c].cat.categories) + else: + self._plot_colorbar(ax, label=cbar_label) if label is not None: self._add_legend_handle(scatter, label) From cf218f7d4d26db3e8ab50b5d1dacf1c2e782266a Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 12:06:04 +0100 Subject: [PATCH 10/19] cleanup --- pandas/plotting/_matplotlib/core.py | 47 +++++++++++------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index dedc5226591e6..8356edcbf5ab0 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -991,30 +991,24 @@ def _make_plot(self): bounds = np.linspace(0, n_cats, n_cats + 1) ticks = np.linspace(0.5, n_cats - 0.5, n_cats) norm = colors.BoundaryNorm(bounds, cmap.N) + yticklabels = self.data[c].cat.categories + else: + ticks = None + norm = None if self.legend and hasattr(self, "label"): label = self.label else: label = None - if color_by_categorical: - scatter = ax.scatter( - data[x].values, - data[y].values, - c=c_values, - label=label, - cmap=cmap, - norm=norm, - **self.kwds, - ) - else: - scatter = ax.scatter( - data[x].values, - data[y].values, - c=c_values, - label=label, - cmap=cmap, - **self.kwds, - ) + scatter = ax.scatter( + data[x].values, + data[y].values, + c=c_values, + label=label, + cmap=cmap, + norm=norm, + **self.kwds, + ) if cb: cbar_label = c if c_is_column else "" if color_by_unordered_categorical: @@ -1028,18 +1022,11 @@ def _make_plot(self): for i in range(n_cats) ] ax.legend(handles=handles, title=cbar_label) - elif color_by_ordered_categorical: - cbar = self._plot_colorbar( - ax, - label=cbar_label, - cmap=cmap, - boundaries=bounds, - ticks=ticks, - norm=norm, - ) - cbar.ax.set_yticklabels(self.data[c].cat.categories) else: - self._plot_colorbar(ax, label=cbar_label) + cbar = self._plot_colorbar(ax, label=cbar_label,) + if color_by_ordered_categorical: + cbar.set_ticks(ticks) + cbar.ax.set_yticklabels(yticklabels) if label is not None: self._add_legend_handle(scatter, label) From 7aae1648efe4c3dd43d92afbf5a6e07d8c62de3b Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 12:06:20 +0100 Subject: [PATCH 11/19] cleanup --- pandas/plotting/_matplotlib/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 8356edcbf5ab0..bc054ada08b1e 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -1023,7 +1023,7 @@ def _make_plot(self): ] ax.legend(handles=handles, title=cbar_label) else: - cbar = self._plot_colorbar(ax, label=cbar_label,) + cbar = self._plot_colorbar(ax, label=cbar_label) if color_by_ordered_categorical: cbar.set_ticks(ticks) cbar.ax.set_yticklabels(yticklabels) From 0ab903d12ca1bb905a35da6cba333d1d7a4d1609 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 13:49:42 +0100 Subject: [PATCH 12/19] plot colorbar in both cases --- pandas/plotting/_matplotlib/core.py | 34 ++++++----------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index bc054ada08b1e..9c29e8d336f57 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -956,10 +956,6 @@ def _make_plot(self): c_is_column = is_hashable(c) and c in self.data.columns color_by_categorical = c_is_column and is_categorical_dtype(self.data[c]) - color_by_ordered_categorical = color_by_categorical and self.data[c].cat.ordered - color_by_unordered_categorical = ( - color_by_categorical and not self.data[c].cat.ordered - ) # plot a colorbar only if a colormap is provided or necessary cb = self.kwds.pop("colorbar", self.colormap or c_is_column) @@ -975,11 +971,10 @@ def _make_plot(self): elif color is not None: c_values = color elif c_is_column: - if not color_by_categorical: - c_values = self.data[c].values - else: + if color_by_categorical: c_values = self.data[c].cat.codes - + else: + c_values = self.data[c].values else: c_values = c @@ -989,11 +984,8 @@ def _make_plot(self): n_cats = len(self.data[c].cat.categories) cmap = colors.ListedColormap([cmap(i) for i in range((cmap.N))]) bounds = np.linspace(0, n_cats, n_cats + 1) - ticks = np.linspace(0.5, n_cats - 0.5, n_cats) norm = colors.BoundaryNorm(bounds, cmap.N) - yticklabels = self.data[c].cat.categories else: - ticks = None norm = None if self.legend and hasattr(self, "label"): @@ -1011,22 +1003,10 @@ def _make_plot(self): ) if cb: cbar_label = c if c_is_column else "" - if color_by_unordered_categorical: - handles = [ - self.plt.scatter( - *([], []), - color=scatter.cmap(scatter.norm(i)), - label=self.data[c].cat.categories[i], - norm=norm, - ) - for i in range(n_cats) - ] - ax.legend(handles=handles, title=cbar_label) - else: - cbar = self._plot_colorbar(ax, label=cbar_label) - if color_by_ordered_categorical: - cbar.set_ticks(ticks) - cbar.ax.set_yticklabels(yticklabels) + cbar = self._plot_colorbar(ax, label=cbar_label) + if color_by_categorical: + cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) + cbar.ax.set_yticklabels(self.data[c].cat.categories) if label is not None: self._add_legend_handle(scatter, label) From b93ddf1fefa28ed36b8531ecc9540ad50446eada Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 14:08:49 +0100 Subject: [PATCH 13/19] update test --- pandas/tests/plotting/test_frame.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index 1b40413b58ef0..c48e738aa422e 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -1190,15 +1190,38 @@ def test_scatterplot_object_data(self): _check_plot_works(df.plot.scatter, x="a", y="b") _check_plot_works(df.plot.scatter, x=0, y=1) - def test_scatterplot_color_by_categorical(self): + @pytest.mark.parametrize("ordered", [True, False]) + @pytest.mark.parametrize( + "categories", + (["setosa", "versicolor", "virginica"], ["versicolor", "virginica", "setosa"]), + ) + def test_scatterplot_color_by_categorical(self, ordered, categories): df = pd.DataFrame( [[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]], columns=["length", "width"], ) df["species"] = pd.Categorical( - ["setosa", "setosa", "virginica", "virginica", "versicolor"] + ["setosa", "setosa", "virginica", "virginica", "versicolor"], + ordered=ordered, + categories=categories, + ) + ax = df.plot.scatter(x=0, y=1, c="species") + colorbar = ax.collections[0].colorbar + + expected_ticks = np.array([0.5, 1.5, 2.5]) + result_ticks = colorbar.get_ticks() + tm.assert_numpy_array_equal(result_ticks, expected_ticks) + + expected_boundaries = np.array([0.0, 1.0, 2.0, 3.0]) + result_boundaries = colorbar._boundaries + tm.assert_numpy_array_equal(result_boundaries, expected_boundaries) + + expected_yticklabels = categories + result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()] + assert all( + result == expected + for result, expected in zip(result_yticklabels, expected_yticklabels) ) - _check_plot_works(df.plot.scatter, x=0, y=1, c="species") @pytest.mark.slow def test_if_scatterplot_colorbar_affects_xaxis_visibility(self): From 5fbc117c1c942c1c77d4299f081dfc99ff4be5cf Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Fri, 22 May 2020 14:09:35 +0100 Subject: [PATCH 14/19] update test --- pandas/tests/plotting/test_frame.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index c48e738aa422e..71fd000c58845 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -1218,10 +1218,7 @@ def test_scatterplot_color_by_categorical(self, ordered, categories): expected_yticklabels = categories result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()] - assert all( - result == expected - for result, expected in zip(result_yticklabels, expected_yticklabels) - ) + assert all(i == j for i, j in zip(result_yticklabels, expected_yticklabels)) @pytest.mark.slow def test_if_scatterplot_colorbar_affects_xaxis_visibility(self): From b2a8b28ecd795767371b2b6e32c2985f6a674f42 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sun, 3 Jan 2021 15:58:41 +0000 Subject: [PATCH 15/19] :art: --- pandas/plotting/_matplotlib/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 2891148886459..04094d1f11df5 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -440,6 +440,7 @@ def _compute_plot_data(self): # no non-numeric frames or series allowed if is_empty: raise TypeError("no numeric data to plot") + self.data = numeric_data.apply( lambda x: self._convert_to_ndarray(x) if not is_categorical_dtype(x) else x ) From b0a8cfa5e8d00c62156074893d08a67b8ed5f6e6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sun, 17 Jan 2021 10:22:37 +0000 Subject: [PATCH 16/19] simplify logic --- pandas/plotting/_matplotlib/core.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 08292c6af9b13..6ad0e7b8b62d5 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -387,6 +387,10 @@ def result(self): return self.axes[0] def _convert_to_ndarray(self, data): + # GH31357: categorical columns are processed separately + if is_categorical_dtype(data): + return data + # GH32073: cast to float if values contain nulled integers if ( is_integer_dtype(data.dtype) or is_float_dtype(data.dtype) @@ -440,9 +444,7 @@ def _compute_plot_data(self): if is_empty: raise TypeError("no numeric data to plot") - self.data = numeric_data.apply( - lambda x: self._convert_to_ndarray(x) if not is_categorical_dtype(x) else x - ) + self.data = numeric_data.apply(self._convert_to_ndarray) def _make_plot(self): raise AbstractMethodError(self) @@ -1029,11 +1031,10 @@ def _make_plot(self): c_values = self.plt.rcParams["patch.facecolor"] elif color is not None: c_values = color + elif color_by_categorical: + c_values = self.data[c].cat.codes elif c_is_column: - if color_by_categorical: - c_values = self.data[c].cat.codes - else: - c_values = self.data[c].values + c_values = self.data[c].values else: c_values = c From 4c15a83c0aa6bb3971e4663524595e10780fc069 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sun, 17 Jan 2021 11:24:37 +0000 Subject: [PATCH 17/19] whatsnew entry --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/tests/plotting/frame/test_frame.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index ab00b749d5725..60bf15658a9f4 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -52,6 +52,7 @@ Other enhancements - :meth:`DataFrame.apply` can now accept NumPy unary operators as strings, e.g. ``df.apply("sqrt")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`Series.apply` can now accept list-like or dictionary-like arguments that aren't lists or dictionaries, e.g. ``ser.apply(np.array(["sum", "mean"]))``, which was already the case for :meth:`DataFrame.apply` (:issue:`39140`) +- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`31357`) .. --------------------------------------------------------------------------- diff --git a/pandas/tests/plotting/frame/test_frame.py b/pandas/tests/plotting/frame/test_frame.py index f433b8f39e862..12d70e7e82604 100644 --- a/pandas/tests/plotting/frame/test_frame.py +++ b/pandas/tests/plotting/frame/test_frame.py @@ -712,7 +712,8 @@ def test_scatterplot_color_by_categorical(self, ordered, categories): categories=categories, ) ax = df.plot.scatter(x=0, y=1, c="species") - colorbar = ax.collections[0].colorbar + (colorbar_collection,) = ax.collections + colorbar = colorbar_collection.colorbar expected_ticks = np.array([0.5, 1.5, 2.5]) result_ticks = colorbar.get_ticks() From cce3461977168b9099f2a2ac460f3eb67af53e10 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Tue, 26 Jan 2021 21:11:48 +0000 Subject: [PATCH 18/19] add example to visualisation --- doc/source/user_guide/visualization.rst | 16 ++++++++++++++++ doc/source/whatsnew/v1.3.0.rst | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/doc/source/user_guide/visualization.rst b/doc/source/user_guide/visualization.rst index c4ee8677a6b0d..29ab73473bd6d 100644 --- a/doc/source/user_guide/visualization.rst +++ b/doc/source/user_guide/visualization.rst @@ -552,6 +552,9 @@ These can be specified by the ``x`` and ``y`` keywords. .. ipython:: python df = pd.DataFrame(np.random.rand(50, 4), columns=["a", "b", "c", "d"]) + df["species"] = pd.Categorical( + ["setosa"] * 20 + ["versicolor"] * 20 + ["virginica"] * 10 + ) @savefig scatter_plot.png df.plot.scatter(x="a", y="b"); @@ -579,6 +582,19 @@ each point: df.plot.scatter(x="a", y="b", c="c", s=50); +.. ipython:: python + :suppress: + + plt.close("all") + +If a categorical column is passed to ``c``, then a discrete colorbar will be produced: + +.. ipython:: python + + @savefig scatter_plot_categorical.png + df.plot.scatter(x="a", y="b", c="species", cmap="viridis", s=50); + + .. ipython:: python :suppress: diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index a2804b9d49f9a..886e034f60d27 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -52,7 +52,7 @@ Other enhancements - :meth:`DataFrame.apply` can now accept NumPy unary operators as strings, e.g. ``df.apply("sqrt")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`) - :meth:`Series.apply` can now accept list-like or dictionary-like arguments that aren't lists or dictionaries, e.g. ``ser.apply(np.array(["sum", "mean"]))``, which was already the case for :meth:`DataFrame.apply` (:issue:`39140`) -- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`31357`) +- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`12380`, :issue:`31357`) - :meth:`.Styler.set_tooltips` allows on hover tooltips to be added to styled HTML dataframes. .. --------------------------------------------------------------------------- From 6e010915e1bfbc55557f91d39594b32e5a73c42d Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Wed, 27 Jan 2021 14:57:22 +0000 Subject: [PATCH 19/19] add versionadded tag --- doc/source/user_guide/visualization.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/user_guide/visualization.rst b/doc/source/user_guide/visualization.rst index 29ab73473bd6d..7911c58b9867e 100644 --- a/doc/source/user_guide/visualization.rst +++ b/doc/source/user_guide/visualization.rst @@ -589,6 +589,8 @@ each point: If a categorical column is passed to ``c``, then a discrete colorbar will be produced: +.. versionadded:: 1.3.0 + .. ipython:: python @savefig scatter_plot_categorical.png