Skip to content

Commit 09f1c7c

Browse files
committed
BUG: Fix scatter plot colors in groupby context to match line plot behavior (pandas-dev#59846)
1 parent 5736b96 commit 09f1c7c

File tree

3 files changed

+80
-7
lines changed

3 files changed

+80
-7
lines changed

Diff for: doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ Plotting
762762
- Bug in :meth:`DataFrame.plot.bar` with ``stacked=True`` where labels on stacked bars with zero-height segments were incorrectly positioned at the base instead of the label position of the previous segment (:issue:`59429`)
763763
- Bug in :meth:`DataFrame.plot.line` raising ``ValueError`` when both ``color`` and a ``dict`` ``style`` are set (:issue:`59461`)
764764
- Bug in :meth:`DataFrame.plot` that causes a shift to the right when the frequency multiplier is greater than one. (:issue:`57587`)
765+
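- Bug in :meth:`DataFrameGroupBy.plot` with ``kind="scatter"`` where every group was drawn in the same color instead of each group getting a different color, as line plots do (:issue:`59846`)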
765766
- Bug in :meth:`Series.plot` with ``kind="pie"`` with :class:`ArrowDtype` (:issue:`59192`)
766767

767768
Groupby/resample/rolling

Diff for: pandas/core/groupby/groupby.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,35 @@ def __init__(self, groupby: GroupBy) -> None:
431431
self._groupby = groupby
432432

433433
def __call__(self, *args, **kwargs):
    """
    Plot every group of the wrapped groupby object.

    When ``kind="scatter"`` is passed as a keyword, all groups are drawn on
    one shared axes and, unless the caller chose colors, each group is given
    the next color of matplotlib's ``axes.prop_cycle``. Every other kind is
    dispatched to each group's ``plot`` through ``_python_apply_general``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to each group's ``plot`` call. For scatter
        plots ``ax`` is always filled in, and ``color`` is filled in only
        when neither ``c`` nor ``color`` was supplied.

    Returns
    -------
    dict or object
        For ``kind="scatter"``, a dict mapping each group name to the value
        returned by that group's ``plot`` call. Otherwise, whatever
        ``_python_apply_general`` returns.
    """
    if kwargs.get("kind") != "scatter":
        # Non-scatter plots: apply ``plot`` group by group.
        def f(group):
            return group.plot(*args, **kwargs)

        # The applied function is explicitly named "plot".
        f.__name__ = "plot"
        return self._groupby._python_apply_general(f, self._groupby._selected_obj)

    # An explicit ``c``/``color`` from the caller must win. Injecting
    # ``color`` next to a user-supplied ``c`` would also make the scatter
    # call fail, because only one of the two may be given.
    user_colored = kwargs.get("c") is not None or kwargs.get("color") is not None

    # ``ax=None`` means "no axes supplied", exactly like omitting ``ax``.
    ax = kwargs.get("ax")

    colors = []
    if ax is None or not user_colored:
        # Imported lazily: matplotlib is only needed for the color cycle
        # and for creating the shared axes.
        import matplotlib.pyplot as plt

        if not user_colored:
            # The property cycle may legitimately have no "color" key; in
            # that case no color is injected at all.
            colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", [])
        if ax is None:
            _, ax = plt.subplots()

    results = {}
    for i, (name, group) in enumerate(self._groupby):
        group_kwargs = kwargs.copy()
        group_kwargs["ax"] = ax
        if colors:
            # Wrap around when there are more groups than cycle colors.
            group_kwargs["color"] = colors[i % len(colors)]
        results[name] = group.plot(*args, **group_kwargs)

    return results
439463

440464
def __getattr__(self, name: str):
441465
def attr(*args, **kwargs):
@@ -546,8 +570,7 @@ def groups(self) -> dict[Hashable, Index]:
546570
2023-02-15 4
547571
dtype: int64
548572
>>> ser.resample("MS").groups
549-
{Timestamp('2023-01-01 00:00:00'): np.int64(2),
550-
Timestamp('2023-02-01 00:00:00'): np.int64(4)}
573+
{Timestamp('2023-01-01 00:00:00'): 2, Timestamp('2023-02-01 00:00:00'): 4}
551574
"""
552575
if isinstance(self.keys, list) and len(self.keys) == 1:
553576
warnings.warn(
@@ -614,7 +637,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
614637
toucan 1 5 6
615638
eagle 7 8 9
616639
>>> df.groupby(by=["a"]).indices
617-
{np.int64(1): array([0, 1]), np.int64(7): array([2])}
640+
{1: array([0, 1]), 7: array([2])}
618641
619642
For Resampler:
620643

Diff for: pandas/tests/plotting/test_groupby.py

+49
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,52 @@ def test_groupby_hist_series_with_legend_raises(self):
152152

153153
with pytest.raises(ValueError, match="Cannot use both legend and label"):
154154
g.hist(legend=True, label="d")
155+
156+
def test_groupby_scatter_colors(self):
    # GH 59846 - each group of a grouped scatter plot should get its own
    # color, as already happens for grouped line plots.
    from matplotlib.collections import PathCollection
    import matplotlib.pyplot as plt

    # Three groups of three points each.
    df = DataFrame(
        {
            "x": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "y": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "group": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
        }
    )

    fig, (ax1, ax2) = plt.subplots(1, 2)
    # Close the figure even when a plot call or an assertion fails, so a
    # failing run does not leak an open figure.
    try:
        # Line plot: the reference behavior (one color per group).
        df.groupby("group").plot(x="x", y="y", ax=ax1, kind="line")

        # Scatter plot: should likewise use one color per group.
        df.groupby("group").plot(x="x", y="y", ax=ax2, kind="scatter")

        line_colors = [line.get_color() for line in ax1.get_lines()]

        scatter_colors = []
        for collection in ax2.collections:
            if not isinstance(collection, PathCollection):
                # Only PathCollection artists are treated as scatter output.
                continue
            face_colors = collection.get_facecolor()
            # With one row per point, the first row is taken as the
            # representative color of the collection.
            first = face_colors[0] if face_colors.ndim > 1 else face_colors
            scatter_colors.append(tuple(first))

        # One artist per group on each axes ...
        assert len(line_colors) == 3
        assert len(scatter_colors) == 3

        # ... and no two groups share a color.
        assert len(set(line_colors)) == 3
        assert len(set(scatter_colors)) == 3
    finally:
        plt.close(fig)

0 commit comments

Comments
 (0)