From 183a686cc6c2cea345235448f1ff9eae7f965016 Mon Sep 17 00:00:00 2001
From: myenugula
Date: Sun, 13 Apr 2025 15:02:41 +0800
Subject: [PATCH] BUG: Fix scatter plot colors in groupby context to match line plot behavior (#59846)

---
Review notes (this section is ignored by git-am):
- Scatter plots keep going through _python_apply_general, so the return
  value has the same shape as for every other plot kind.
- A per-group color is injected only when the caller passed neither ``c``
  nor ``color`` and did not request ``subplots=True``.

 doc/source/whatsnew/v3.0.0.rst        |  1 +
 pandas/core/groupby/groupby.py        | 24 +++++++++++++
 pandas/tests/plotting/test_groupby.py | 71 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 96 insertions(+)

diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst
index e6fafc8b1b14c..700eca0d9a795 100644
--- a/doc/source/whatsnew/v3.0.0.rst
+++ b/doc/source/whatsnew/v3.0.0.rst
@@ -762,6 +762,7 @@ Plotting
 - Bug in :meth:`DataFrame.plot.bar` with ``stacked=True`` where labels on stacked bars with zero-height segments were incorrectly positioned at the base instead of the label position of the previous segment (:issue:`59429`)
 - Bug in :meth:`DataFrame.plot.line` raising ``ValueError`` when set both color and a ``dict`` style (:issue:`59461`)
 - Bug in :meth:`DataFrame.plot` that causes a shift to the right when the frequency multiplier is greater than one. (:issue:`57587`)
+- Bug in :meth:`DataFrameGroupBy.plot` with ``kind="scatter"`` where all groups used the same color instead of different colors for each group (:issue:`59846`)
 - Bug in :meth:`Series.plot` with ``kind="pie"`` with :class:`ArrowDtype` (:issue:`59192`)
 
 Groupby/resample/rolling
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index f9438b348c140..36a02ac34e8e8 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -431,7 +431,31 @@ def __init__(self, groupby: GroupBy) -> None:
         self._groupby = groupby
 
     def __call__(self, *args, **kwargs):
+        # GH 59846: every group is drawn by its own ``.plot`` call, and scatter
+        # plots would otherwise give all groups the same color. Hand out one
+        # color per group from the active matplotlib color cycle instead.
+        # An explicit ``c``/``color`` always wins, and ``subplots=True`` keeps
+        # the regular code path.
+        group_colors = None
+        if (
+            kwargs.get("kind") == "scatter"
+            and not kwargs.get("subplots", False)
+            and "c" not in kwargs
+            and "color" not in kwargs
+        ):
+            from itertools import cycle
+
+            import matplotlib as mpl
+
+            palette = mpl.rcParams["axes.prop_cycle"].by_key().get("color")
+            if palette:
+                group_colors = cycle(palette)
+
         def f(self):
+            if group_colors is not None:
+                # NOTE: assumes ``f`` is invoked exactly once per group, so
+                # each group consumes exactly one color from the cycle.
+                return self.plot(*args, color=next(group_colors), **kwargs)
             return self.plot(*args, **kwargs)
 
         f.__name__ = "plot"
diff --git a/pandas/tests/plotting/test_groupby.py b/pandas/tests/plotting/test_groupby.py
index 0cb125d822fd1..f87657cd4c0c4 100644
--- a/pandas/tests/plotting/test_groupby.py
+++ b/pandas/tests/plotting/test_groupby.py
@@ -152,3 +152,74 @@ def test_groupby_hist_series_with_legend_raises(self):
 
         with pytest.raises(ValueError, match="Cannot use both legend and label"):
             g.hist(legend=True, label="d")
+
+    def test_groupby_scatter_colors(self):
+        # GH 59846 - Test that scatter plots use different colors for different groups
+        # similar to how line plots do
+        from matplotlib.collections import PathCollection
+        import matplotlib.pyplot as plt
+
+        # Create test data with distinct groups
+        df = DataFrame(
+            {
+                "x": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+                "y": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+                "group": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
+            }
+        )
+
+        # Set up a figure with both line and scatter plots
+        fig, (ax1, ax2) = plt.subplots(1, 2)
+
+        # Plot line chart (known to use different colors for different groups)
+        df.groupby("group").plot(x="x", y="y", ax=ax1, kind="line")
+
+        # Plot scatter chart (should also use different colors for different groups)
+        df.groupby("group").plot(x="x", y="y", ax=ax2, kind="scatter")
+
+        # Get the colors used in the line plot and scatter plot
+        line_colors = [line.get_color() for line in ax1.get_lines()]
+
+        # Get scatter colors
+        scatter_colors = []
+        for collection in ax2.collections:
+            if isinstance(collection, PathCollection):  # This is a scatter plot
+                # Get the face colors (might be array of RGBA values)
+                face_colors = collection.get_facecolor()
+                # If multiple points with same color, we get the first one
+                if face_colors.ndim > 1:
+                    scatter_colors.append(tuple(face_colors[0]))
+                else:
+                    scatter_colors.append(tuple(face_colors))
+
+        # Assert that we have the right number of colors (one per group)
+        assert len(line_colors) == 3
+        assert len(scatter_colors) == 3
+
+        # Assert that the colors are all different (fixed behavior)
+        assert len(set(scatter_colors)) == 3
+
+        plt.close(fig)
+
+    def test_groupby_scatter_respects_explicit_color(self):
+        # GH 59846 - an explicit ``color`` must not be replaced by the
+        # automatic per-group colors
+        import matplotlib.colors as mcolors
+        import matplotlib.pyplot as plt
+
+        df = DataFrame(
+            {
+                "x": [1, 2, 3, 4],
+                "y": [1, 2, 3, 4],
+                "group": ["A", "A", "B", "B"],
+            }
+        )
+        fig, ax = plt.subplots()
+
+        df.groupby("group").plot(x="x", y="y", ax=ax, kind="scatter", color="red")
+
+        expected = mcolors.to_rgba("red")
+        for collection in ax.collections:
+            assert tuple(collection.get_facecolor()[0]) == expected
+
+        plt.close(fig)