Skip to content

Commit 183a686

Browse files
committed
BUG: Fix scatter plot colors in groupby context to match line plot behavior (#59846)
1 parent 5736b96 commit 183a686

File tree

3 files changed

+94
-4
lines changed

3 files changed

+94
-4
lines changed

Diff for: doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ Plotting
762762
- Bug in :meth:`DataFrame.plot.bar` with ``stacked=True`` where labels on stacked bars with zero-height segments were incorrectly positioned at the base instead of the label position of the previous segment (:issue:`59429`)
763763
- Bug in :meth:`DataFrame.plot.line` raising ``ValueError`` when set both color and a ``dict`` style (:issue:`59461`)
764764
- Bug in :meth:`DataFrame.plot` that causes a shift to the right when the frequency multiplier is greater than one. (:issue:`57587`)
765+
- Bug in :meth:`DataFrameGroupBy.plot` with ``kind="scatter"`` where all groups used the same color instead of different colors for each group (:issue:`59846`)
765766
- Bug in :meth:`Series.plot` with ``kind="pie"`` with :class:`ArrowDtype` (:issue:`59192`)
766767

767768
Groupby/resample/rolling

Diff for: pandas/core/groupby/groupby.py

+45-4
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,52 @@ def __init__(self, groupby: GroupBy) -> None:
431431
self._groupby = groupby
432432

433433
def __call__(self, *args, **kwargs):
434-
def f(self):
435-
return self.plot(*args, **kwargs)
434+
# Special case for scatter plots to enable automatic colors in groupby context
435+
if kwargs.get("kind") == "scatter":
436+
# Get the groupby data and iterator
437+
obj = self._groupby._selected_obj
438+
439+
# Handle subplots
440+
if kwargs.pop("subplots", False):
441+
return self._subplots(obj, *args, **kwargs)
442+
443+
# Plot each group with color from index position
444+
results = {}
445+
for i, (name, group) in enumerate(self._groupby):
446+
if self._groupby._selection is not None:
447+
group = group[self._groupby._selection]
448+
449+
# Create a copy of kwargs with explicit color for each group
450+
plot_kwargs = kwargs.copy()
451+
452+
# Get colors from matplotlib's color cycle
453+
import matplotlib as mpl
454+
455+
colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
456+
# Set explicit color for this group
457+
plot_kwargs["color"] = colors[i % len(colors)]
458+
459+
# Create the plot for this group
460+
# Remove 'kind' to avoid duplicate keyword argument error
461+
scatter_kwargs = plot_kwargs.copy()
462+
if "kind" in scatter_kwargs:
463+
del scatter_kwargs["kind"]
464+
465+
if hasattr(group.plot, "scatter"):
466+
result = group.plot.scatter(*args, **scatter_kwargs)
467+
else:
468+
result = group.plot(*args, kind="scatter", **scatter_kwargs)
469+
470+
results[name] = result
471+
472+
return results
473+
else:
474+
# Original implementation for non-scatter plots
475+
def f(self):
476+
return self.plot(*args, **kwargs)
436477

437-
f.__name__ = "plot"
438-
return self._groupby._python_apply_general(f, self._groupby._selected_obj)
478+
f.__name__ = "plot"
479+
return self._groupby._python_apply_general(f, self._groupby._selected_obj)
439480

440481
def __getattr__(self, name: str):
441482
def attr(*args, **kwargs):

Diff for: pandas/tests/plotting/test_groupby.py

+48
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,51 @@ def test_groupby_hist_series_with_legend_raises(self):
152152

153153
with pytest.raises(ValueError, match="Cannot use both legend and label"):
154154
g.hist(legend=True, label="d")
155+
156+
def test_groupby_scatter_colors(self):
157+
# GH 59846 - Test that scatter plots use different colors for different groups
158+
# similar to how line plots do
159+
from matplotlib.collections import PathCollection
160+
import matplotlib.pyplot as plt
161+
162+
# Create test data with distinct groups
163+
df = DataFrame(
164+
{
165+
"x": [1, 2, 3, 4, 5, 6, 7, 8, 9],
166+
"y": [1, 2, 3, 4, 5, 6, 7, 8, 9],
167+
"group": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
168+
}
169+
)
170+
171+
# Set up a figure with both line and scatter plots
172+
fig, (ax1, ax2) = plt.subplots(1, 2)
173+
174+
# Plot line chart (known to use different colors for different groups)
175+
df.groupby("group").plot(x="x", y="y", ax=ax1, kind="line")
176+
177+
# Plot scatter chart (should also use different colors for different groups)
178+
df.groupby("group").plot(x="x", y="y", ax=ax2, kind="scatter")
179+
180+
# Get the colors used in the line plot and scatter plot
181+
line_colors = [line.get_color() for line in ax1.get_lines()]
182+
183+
# Get scatter colors
184+
scatter_colors = []
185+
for collection in ax2.collections:
186+
if isinstance(collection, PathCollection): # This is a scatter plot
187+
# Get the face colors (might be array of RGBA values)
188+
face_colors = collection.get_facecolor()
189+
# If multiple points with same color, we get the first one
190+
if face_colors.ndim > 1:
191+
scatter_colors.append(tuple(face_colors[0]))
192+
else:
193+
scatter_colors.append(tuple(face_colors))
194+
195+
# Assert that we have the right number of colors (one per group)
196+
assert len(line_colors) == 3
197+
assert len(scatter_colors) == 3
198+
199+
# Assert that the colors are all different (fixed behavior)
200+
assert len(set(scatter_colors)) == 3
201+
202+
plt.close(fig)

0 commit comments

Comments
 (0)