-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: categorical scatter plot #34293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
d8d8c3c
a167313
d10c5c1
7294aa2
50cd05f
4579142
6d3fe9e
3846804
ef4b03d
cf218f7
7aae164
0ab903d
b93ddf1
5fbc117
572ecfc
b2a8b28
efaaae6
b0a8cfa
4c15a83
b65b103
6560bb0
cce3461
6e01091
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from pandas.util._decorators import cache_readonly | ||
|
||
from pandas.core.dtypes.common import ( | ||
is_categorical_dtype, | ||
is_extension_array_dtype, | ||
is_float, | ||
is_float_dtype, | ||
|
@@ -388,6 +389,10 @@ def result(self): | |
return self.axes[0] | ||
|
||
def _convert_to_ndarray(self, data): | ||
# GH31357: categorical columns are processed separately | ||
if is_categorical_dtype(data): | ||
return data | ||
|
||
# GH32073: cast to float if values contain nulled integers | ||
if ( | ||
is_integer_dtype(data.dtype) or is_float_dtype(data.dtype) | ||
|
@@ -974,7 +979,7 @@ def _plot_colorbar(self, ax: Axes, **kwds): | |
|
||
if mpl_ge_3_0_0(): | ||
# The workaround below is no longer necessary. | ||
return | ||
Comment on lines
981
to
-977
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the comment in the line above still relevant?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so, yes, it's just saying that points = ax.get_position().get_points()
cbar_points = cbar.ax.get_position().get_points()
cbar.ax.set_position(
[
cbar_points[0, 0],
points[0, 1],
cbar_points[1, 0] - cbar_points[0, 0],
points[1, 1] - points[0, 1],
]
)
# To see the discrepancy in axis heights uncomment
# the following two lines:
# print(points[1, 1] - points[0, 1])
# print(cbar_points[1, 1] - cbar_points[0, 1])
return cbar is no longer necessary if mpl is of a modern enough version |
||
return cbar | ||
|
||
points = ax.get_position().get_points() | ||
cbar_points = cbar.ax.get_position().get_points() | ||
|
@@ -992,6 +997,8 @@ def _plot_colorbar(self, ax: Axes, **kwds): | |
# print(points[1, 1] - points[0, 1]) | ||
# print(cbar_points[1, 1] - cbar_points[0, 1]) | ||
|
||
return cbar | ||
|
||
|
||
class ScatterPlot(PlanePlot): | ||
_kind = "scatter" | ||
|
@@ -1014,6 +1021,8 @@ def _make_plot(self): | |
|
||
c_is_column = is_hashable(c) and c in self.data.columns | ||
|
||
color_by_categorical = c_is_column and is_categorical_dtype(self.data[c]) | ||
|
||
# pandas uses colormap, matplotlib uses cmap. | ||
cmap = self.colormap or "Greys" | ||
cmap = self.plt.cm.get_cmap(cmap) | ||
|
@@ -1024,11 +1033,22 @@ def _make_plot(self): | |
c_values = self.plt.rcParams["patch.facecolor"] | ||
elif color is not None: | ||
c_values = color | ||
elif color_by_categorical: | ||
c_values = self.data[c].cat.codes | ||
elif c_is_column: | ||
c_values = self.data[c].values | ||
else: | ||
c_values = c | ||
|
||
if color_by_categorical: | ||
from matplotlib import colors | ||
|
||
n_cats = len(self.data[c].cat.categories) | ||
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)]) | ||
bounds = np.linspace(0, n_cats, n_cats + 1) | ||
norm = colors.BoundaryNorm(bounds, cmap.N) | ||
else: | ||
norm = None | ||
# plot colorbar if | ||
# 1. colormap is assigned, and | ||
# 2.`c` is a column containing only numeric values | ||
|
@@ -1045,11 +1065,15 @@ def _make_plot(self): | |
c=c_values, | ||
label=label, | ||
cmap=cmap, | ||
norm=norm, | ||
**self.kwds, | ||
) | ||
if cb: | ||
cbar_label = c if c_is_column else "" | ||
self._plot_colorbar(ax, label=cbar_label) | ||
cbar = self._plot_colorbar(ax, label=cbar_label) | ||
if color_by_categorical: | ||
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) | ||
cbar.ax.set_yticklabels(self.data[c].cat.categories) | ||
|
||
if label is not None: | ||
self._add_legend_handle(scatter, label) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -696,6 +696,37 @@ def test_scatterplot_object_data(self): | |
_check_plot_works(df.plot.scatter, x="a", y="b") | ||
_check_plot_works(df.plot.scatter, x=0, y=1) | ||
|
||
@pytest.mark.parametrize("ordered", [True, False]) | ||
@pytest.mark.parametrize( | ||
"categories", | ||
(["setosa", "versicolor", "virginica"], ["versicolor", "virginica", "setosa"]), | ||
) | ||
def test_scatterplot_color_by_categorical(self, ordered, categories): | ||
Comment on lines
+699
to
+704
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to add assertions on the colorbar ticklabels for the ordered case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I couldn't find a simple way to do this using the public API from mpl, but can look into it further (or do you have a suggestion?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I look at my comment once again and I cannot figure out what kind of assertions I asked you to consider. |
||
df = DataFrame( | ||
[[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]], | ||
columns=["length", "width"], | ||
) | ||
df["species"] = pd.Categorical( | ||
["setosa", "setosa", "virginica", "virginica", "versicolor"], | ||
ordered=ordered, | ||
categories=categories, | ||
) | ||
ax = df.plot.scatter(x=0, y=1, c="species") | ||
(colorbar_collection,) = ax.collections | ||
colorbar = colorbar_collection.colorbar | ||
|
||
expected_ticks = np.array([0.5, 1.5, 2.5]) | ||
result_ticks = colorbar.get_ticks() | ||
tm.assert_numpy_array_equal(result_ticks, expected_ticks) | ||
|
||
expected_boundaries = np.array([0.0, 1.0, 2.0, 3.0]) | ||
result_boundaries = colorbar._boundaries | ||
tm.assert_numpy_array_equal(result_boundaries, expected_boundaries) | ||
|
||
expected_yticklabels = categories | ||
result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()] | ||
assert all(i == j for i, j in zip(result_yticklabels, expected_yticklabels)) | ||
|
||
@pytest.mark.parametrize("x, y", [("x", "y"), ("y", "x"), ("y", "y")]) | ||
def test_plot_scatter_with_categorical_data(self, x, y): | ||
# after fixing GH 18755, should be able to plot categorical data | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's mention the other issue here (and just close it too)