ENH: categorical scatter plot #34293

Merged
merged 23 commits into from
Jan 27, 2021
Changes from 16 commits
d8d8c3c
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
a167313
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
d10c5c1
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
7294aa2
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
50cd05f
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
4579142
add legend with colors if coloring by categorical
MarcoGorelli May 21, 2020
6d3fe9e
add test
MarcoGorelli May 21, 2020
3846804
revert empty line
MarcoGorelli May 21, 2020
ef4b03d
discrete colorbar in case of ordered categorical
MarcoGorelli May 22, 2020
cf218f7
cleanup
MarcoGorelli May 22, 2020
7aae164
cleanup
MarcoGorelli May 22, 2020
0ab903d
plot colorbar in both cases
MarcoGorelli May 22, 2020
b93ddf1
update test
MarcoGorelli May 22, 2020
5fbc117
update test
MarcoGorelli May 22, 2020
572ecfc
Merge remote-tracking branch 'upstream/master' into categorical-scatter
MarcoGorelli Jan 3, 2021
b2a8b28
:art:
MarcoGorelli Jan 3, 2021
efaaae6
Merge remote-tracking branch 'upstream/master' into categorical-scatter
MarcoGorelli Jan 17, 2021
b0a8cfa
simplify logic
MarcoGorelli Jan 17, 2021
4c15a83
whatsnew entry
MarcoGorelli Jan 17, 2021
b65b103
Merge remote-tracking branch 'upstream/master' into categorical-scatter
MarcoGorelli Jan 18, 2021
6560bb0
Merge remote-tracking branch 'upstream/master' into categorical-scatter
MarcoGorelli Jan 24, 2021
cce3461
add example to visualisation
MarcoGorelli Jan 26, 2021
6e01091
add versionadded tag
MarcoGorelli Jan 27, 2021
31 changes: 27 additions & 4 deletions pandas/plotting/_matplotlib/core.py
@@ -9,6 +9,7 @@
 from pandas.util._decorators import cache_readonly

 from pandas.core.dtypes.common import (
+    is_categorical_dtype,
     is_extension_array_dtype,
     is_float,
     is_float_dtype,
@@ -440,7 +441,9 @@ def _compute_plot_data(self):
         if is_empty:
             raise TypeError("no numeric data to plot")

-        self.data = numeric_data.apply(self._convert_to_ndarray)
+        self.data = numeric_data.apply(
Contributor

Can you just change _convert_to_ndarray to handle this case? (And mention it in the docstring.)

+            lambda x: self._convert_to_ndarray(x) if not is_categorical_dtype(x) else x
+        )

     def _make_plot(self):
         raise AbstractMethodError(self)
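
The reviewer's suggestion above (handle the categorical case inside the conversion helper rather than in a lambda at the call site) could look roughly like the sketch below. `convert_column` is a hypothetical stand-in written for illustration, not pandas' actual `_convert_to_ndarray`, and it uses `isinstance(col.dtype, pd.CategoricalDtype)` as an assumed equivalent of the `is_categorical_dtype` check in the diff.

```python
import numpy as np
import pandas as pd


def convert_column(col: pd.Series):
    """Hypothetical stand-in for the plotting helper; not pandas' own code."""
    # Categorical columns are passed through untouched so that
    # .cat.codes and .cat.categories stay available to the plotting code.
    if isinstance(col.dtype, pd.CategoricalDtype):
        return col
    # Every other column is converted to a plain NumPy array.
    return np.asarray(col)


df = pd.DataFrame(
    {
        "length": [5.1, 4.9, 7.0],
        "species": pd.Categorical(["setosa", "setosa", "virginica"]),
    }
)
converted = {name: convert_column(df[name]) for name in df.columns}
```

With the check inside the helper, the call site can go back to passing the helper directly to `apply`, which is what the review comment asks for.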
@@ -973,7 +976,7 @@ def _plot_colorbar(self, ax: "Axes", **kwds):

         if mpl_ge_3_0_0():
             # The workaround below is no longer necessary.
-            return
Comment on lines 981 to -977
Member

Is the comment in the line above still relevant?

# The workaround below is no longer necessary.

Member Author

I think so, yes, it's just saying that

        points = ax.get_position().get_points()
        cbar_points = cbar.ax.get_position().get_points()

        cbar.ax.set_position(
            [
                cbar_points[0, 0],
                points[0, 1],
                cbar_points[1, 0] - cbar_points[0, 0],
                points[1, 1] - points[0, 1],
            ]
        )
        # To see the discrepancy in axis heights uncomment
        # the following two lines:
        # print(points[1, 1] - points[0, 1])
        # print(cbar_points[1, 1] - cbar_points[0, 1])

        return cbar

is no longer necessary if mpl is of a modern enough version

+            return cbar

         points = ax.get_position().get_points()
         cbar_points = cbar.ax.get_position().get_points()
@@ -991,6 +994,8 @@ def _plot_colorbar(self, ax: "Axes", **kwds):
         # print(points[1, 1] - points[0, 1])
         # print(cbar_points[1, 1] - cbar_points[0, 1])

+        return cbar
+

 class ScatterPlot(PlanePlot):
     _kind = "scatter"
@@ -1013,6 +1018,8 @@ def _make_plot(self):

         c_is_column = is_hashable(c) and c in self.data.columns

+        color_by_categorical = c_is_column and is_categorical_dtype(self.data[c])
+
         # pandas uses colormap, matplotlib uses cmap.
         cmap = self.colormap or "Greys"
         cmap = self.plt.cm.get_cmap(cmap)
@@ -1024,10 +1031,22 @@ def _make_plot(self):
         elif color is not None:
             c_values = color
         elif c_is_column:
-            c_values = self.data[c].values
+            if color_by_categorical:
+                c_values = self.data[c].cat.codes
+            else:
+                c_values = self.data[c].values
Member

Maybe slightly modify the logic here?

...
elif color_by_categorical:
  c_values = self.data[c].cat.codes
elif c_is_column:
  c_values = self.data[c].values
else:
  c_values = c

This allows one to eliminate the inner if-else statement.

Member Author

yes, good catch, thank you!

         else:
             c_values = c

+        if color_by_categorical:
+            from matplotlib import colors
+
+            n_cats = len(self.data[c].cat.categories)
+            cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
+            bounds = np.linspace(0, n_cats, n_cats + 1)
+            norm = colors.BoundaryNorm(bounds, cmap.N)
+        else:
+            norm = None
         # plot colorbar if
         # 1. colormap is assigned, and
         # 2.`c` is a column containing only numeric values
@@ -1044,11 +1063,15 @@ def _make_plot(self):
             c=c_values,
             label=label,
             cmap=cmap,
+            norm=norm,
             **self.kwds,
         )
         if cb:
             cbar_label = c if c_is_column else ""
-            self._plot_colorbar(ax, label=cbar_label)
+            cbar = self._plot_colorbar(ax, label=cbar_label)
+            if color_by_categorical:
+                cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
+                cbar.ax.set_yticklabels(self.data[c].cat.categories)

         if label is not None:
             self._add_legend_handle(scatter, label)
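
The categorical branch in the `_make_plot` changes above combines three matplotlib pieces: integer category codes as colour values, a `BoundaryNorm` with one bin per category, and colorbar ticks placed at the bin midpoints. The standalone sketch below shows that combination outside pandas. It is an illustration under assumptions, not the PR's code: the data and the `viridis` colormap are made up, the codes are built with a list lookup standing in for `Series.cat.codes`, and it samples exactly one colour per category instead of reusing all `cmap.N` colours as the diff does.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib import colors

categories = ["setosa", "versicolor", "virginica"]
species = ["setosa", "setosa", "virginica", "virginica", "versicolor"]
x = [5.1, 4.9, 7.0, 6.4, 5.9]
y = [3.5, 3.0, 3.2, 3.2, 3.0]

# Integer code per point, standing in for Series.cat.codes.
codes = [categories.index(s) for s in species]
n_cats = len(categories)

# One colour per category, sampled evenly from a continuous colormap.
base = plt.get_cmap("viridis")
cmap = colors.ListedColormap([base(i / (n_cats - 1)) for i in range(n_cats)])

# Bin edges 0, 1, ..., n_cats: code k falls into the bin [k, k + 1).
bounds = list(range(n_cats + 1))
norm = colors.BoundaryNorm(bounds, cmap.N)

fig, ax = plt.subplots()
points = ax.scatter(x, y, c=codes, cmap=cmap, norm=norm)
cbar = fig.colorbar(points, ax=ax)

# Put one tick in the middle of each bin and label it with its category.
cbar.set_ticks([k + 0.5 for k in range(n_cats)])
cbar.set_ticklabels(categories)
```

Placing the ticks at `k + 0.5` rather than at the integer codes centres each label on its colour block, which is the same reason the diff uses `np.linspace(0.5, n_cats - 0.5, n_cats)`.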
30 changes: 30 additions & 0 deletions pandas/tests/plotting/frame/test_frame.py
@@ -696,6 +696,36 @@ def test_scatterplot_object_data(self):
         _check_plot_works(df.plot.scatter, x="a", y="b")
         _check_plot_works(df.plot.scatter, x=0, y=1)

+    @pytest.mark.parametrize("ordered", [True, False])
+    @pytest.mark.parametrize(
+        "categories",
+        (["setosa", "versicolor", "virginica"], ["versicolor", "virginica", "setosa"]),
+    )
+    def test_scatterplot_color_by_categorical(self, ordered, categories):
Comment on lines +699 to +704
Member

Is there a way to add assertions on the colorbar ticklabels for the ordered case?

Member Author

I couldn't find a simple way to do this using the public API from mpl, but can look into it further (or do you have a suggestion?)

Member

I looked at my comment again and cannot figure out what kind of assertions I was asking you to consider.
You already check the colorbar tick labels against the expected values.
Sorry for the noise.

+        df = DataFrame(
+            [[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]],
+            columns=["length", "width"],
+        )
+        df["species"] = pd.Categorical(
+            ["setosa", "setosa", "virginica", "virginica", "versicolor"],
+            ordered=ordered,
+            categories=categories,
+        )
+        ax = df.plot.scatter(x=0, y=1, c="species")
+        colorbar = ax.collections[0].colorbar
+
+        expected_ticks = np.array([0.5, 1.5, 2.5])
+        result_ticks = colorbar.get_ticks()
+        tm.assert_numpy_array_equal(result_ticks, expected_ticks)
+
+        expected_boundaries = np.array([0.0, 1.0, 2.0, 3.0])
+        result_boundaries = colorbar._boundaries
+        tm.assert_numpy_array_equal(result_boundaries, expected_boundaries)
+
+        expected_yticklabels = categories
+        result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()]
+        assert all(i == j for i, j in zip(result_yticklabels, expected_yticklabels))
+
     @pytest.mark.parametrize("x, y", [("x", "y"), ("y", "x"), ("y", "y")])
     def test_plot_scatter_with_categorical_data(self, x, y):
         # after fixing GH 18755, should be able to plot categorical data
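
The expected values in `test_scatterplot_color_by_categorical` follow directly from the number of categories, which is why the same ticks and boundaries are expected for both category orderings and only the labels change. The pure-Python sketch below reproduces that arithmetic. `colorbar_layout` is a hypothetical helper written for illustration; it mirrors the two `np.linspace` calls in the core.py diff with plain lists and is not part of pandas.

```python
def colorbar_layout(categories):
    """Return (boundaries, ticks) for a discrete colorbar over `categories`."""
    n_cats = len(categories)
    # Same values as np.linspace(0, n_cats, n_cats + 1): edges 0, 1, ..., n_cats.
    boundaries = [float(k) for k in range(n_cats + 1)]
    # Same values as np.linspace(0.5, n_cats - 0.5, n_cats): the bin midpoints.
    ticks = [k + 0.5 for k in range(n_cats)]
    return boundaries, ticks


# Second parametrization from the test: a non-alphabetical category order.
categories = ["versicolor", "virginica", "setosa"]
species = ["setosa", "setosa", "virginica", "virginica", "versicolor"]

# Position of each value in `categories`, standing in for Series.cat.codes.
codes = [categories.index(s) for s in species]

boundaries, ticks = colorbar_layout(categories)
# Each tick is labelled with the category at the same position.
labels_at_ticks = dict(zip(ticks, categories))
```

Reordering the categories changes `codes` and `labels_at_ticks` but leaves `boundaries` and `ticks` untouched, which matches the test comparing the tick labels against `categories` itself.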