Skip to content

Commit 297c59a

Browse files
authored
ENH: Allow column grouping in DataFrame.plot (#29944)
1 parent f30c7d7 commit 297c59a

File tree

4 files changed

+225
-6
lines changed

4 files changed

+225
-6
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ Other enhancements
133133
- :meth:`MultiIndex.to_frame` now supports the argument ``allow_duplicates`` and raises on duplicate labels if it is missing or False (:issue:`45245`)
134134
- :class:`StringArray` now accepts array-likes containing nan-likes (``None``, ``np.nan``) for the ``values`` parameter in its constructor in addition to strings and :attr:`pandas.NA`. (:issue:`40839`)
135135
- Improved the rendering of ``categories`` in :class:`CategoricalIndex` (:issue:`45218`)
136+
- :meth:`DataFrame.plot` will now allow the ``subplots`` parameter to be a list of iterables specifying column groups, so that columns may be grouped together in the same subplot (:issue:`29688`).
136137
- :meth:`to_numeric` now preserves float64 arrays when downcasting would generate values not representable in float32 (:issue:`43693`)
137138
- :meth:`Series.reset_index` and :meth:`DataFrame.reset_index` now support the argument ``allow_duplicates`` (:issue:`44410`)
138139
- :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now supports `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)

pandas/plotting/_core.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -649,8 +649,18 @@ class PlotAccessor(PandasObject):
649649
- 'hexbin' : hexbin plot (DataFrame only)
650650
ax : matplotlib axes object, default None
651651
An axes of the current figure.
652-
subplots : bool, default False
653-
Make separate subplots for each column.
652+
subplots : bool or sequence of iterables, default False
653+
Whether to group columns into subplots:
654+
655+
- ``False`` : No subplots will be used
656+
- ``True`` : Make separate subplots for each column.
657+
- sequence of iterables of column labels: Create a subplot for each
658+
group of columns. For example `[('a', 'c'), ('b', 'd')]` will
659+
create 2 subplots: one with columns 'a' and 'c', and one
660+
with columns 'b' and 'd'. Remaining columns that aren't specified
661+
will be plotted in additional subplots (one per column).
662+
.. versionadded:: 1.5.0
663+
654664
sharex : bool, default True if ax is None else False
655665
In case ``subplots=True``, share x axis and set some x axis labels
656666
to invisible; defaults to True if ax is None otherwise False if

pandas/plotting/_matplotlib/core.py

+128-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import (
44
TYPE_CHECKING,
55
Hashable,
6+
Iterable,
7+
Sequence,
68
)
79
import warnings
810

@@ -102,7 +104,7 @@ def __init__(
102104
data,
103105
kind=None,
104106
by: IndexLabel | None = None,
105-
subplots=False,
107+
subplots: bool | Sequence[Sequence[str]] = False,
106108
sharex=None,
107109
sharey=False,
108110
use_index=True,
@@ -166,8 +168,7 @@ def __init__(
166168
self.kind = kind
167169

168170
self.sort_columns = sort_columns
169-
170-
self.subplots = subplots
171+
self.subplots = self._validate_subplots_kwarg(subplots)
171172

172173
if sharex is None:
173174

@@ -253,6 +254,112 @@ def __init__(
253254

254255
self._validate_color_args()
255256

257+
def _validate_subplots_kwarg(
258+
self, subplots: bool | Sequence[Sequence[str]]
259+
) -> bool | list[tuple[int, ...]]:
260+
"""
261+
Validate the subplots parameter
262+
263+
- check type and content
264+
- check for duplicate columns
265+
- check for invalid column names
266+
- convert column names into indices
267+
- add missing columns in a group of their own
268+
See comments in code below for more details.
269+
270+
Parameters
271+
----------
272+
subplots : subplots parameters as passed to PlotAccessor
273+
274+
Returns
275+
-------
276+
validated subplots : a bool or a list of tuples of column indices. Columns
277+
in the same tuple will be grouped together in the resulting plot.
278+
"""
279+
280+
if isinstance(subplots, bool):
281+
return subplots
282+
elif not isinstance(subplots, Iterable):
283+
raise ValueError("subplots should be a bool or an iterable")
284+
285+
supported_kinds = (
286+
"line",
287+
"bar",
288+
"barh",
289+
"hist",
290+
"kde",
291+
"density",
292+
"area",
293+
"pie",
294+
)
295+
if self._kind not in supported_kinds:
296+
raise ValueError(
297+
"When subplots is an iterable, kind must be "
298+
f"one of {', '.join(supported_kinds)}. Got {self._kind}."
299+
)
300+
301+
if isinstance(self.data, ABCSeries):
302+
raise NotImplementedError(
303+
"An iterable subplots for a Series is not supported."
304+
)
305+
306+
columns = self.data.columns
307+
if isinstance(columns, ABCMultiIndex):
308+
raise NotImplementedError(
309+
"An iterable subplots for a DataFrame with a MultiIndex column "
310+
"is not supported."
311+
)
312+
313+
if columns.nunique() != len(columns):
314+
raise NotImplementedError(
315+
"An iterable subplots for a DataFrame with non-unique column "
316+
"labels is not supported."
317+
)
318+
319+
# subplots is a list of tuples where each tuple is a group of
320+
# columns to be grouped together (one ax per group).
321+
# we consolidate the subplots list such that:
322+
# - the tuples contain indices instead of column names
323+
# - the columns that aren't yet in the list are added in a group
324+
# of their own.
325+
# For example with columns from a to g, and
326+
# subplots = [(a, c), (b, f, e)],
327+
# we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)]
328+
# This way, we can handle self.subplots in a homogeneous manner
329+
# later.
330+
# TODO: also accept indices instead of just names?
331+
332+
out = []
333+
seen_columns: set[Hashable] = set()
334+
for group in subplots:
335+
if not is_list_like(group):
336+
raise ValueError(
337+
"When subplots is an iterable, each entry "
338+
"should be a list/tuple of column names."
339+
)
340+
idx_locs = columns.get_indexer_for(group)
341+
if (idx_locs == -1).any():
342+
bad_labels = np.extract(idx_locs == -1, group)
343+
raise ValueError(
344+
f"Column label(s) {list(bad_labels)} not found in the DataFrame."
345+
)
346+
else:
347+
unique_columns = set(group)
348+
duplicates = seen_columns.intersection(unique_columns)
349+
if duplicates:
350+
raise ValueError(
351+
"Each column should be in only one subplot. "
352+
f"Columns {duplicates} were found in multiple subplots."
353+
)
354+
seen_columns = seen_columns.union(unique_columns)
355+
out.append(tuple(idx_locs))
356+
357+
unseen_columns = columns.difference(seen_columns)
358+
for column in unseen_columns:
359+
idx_loc = columns.get_loc(column)
360+
out.append((idx_loc,))
361+
return out
362+
256363
def _validate_color_args(self):
257364
if (
258365
"color" in self.kwds
@@ -371,8 +478,11 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
371478

372479
def _setup_subplots(self):
373480
if self.subplots:
481+
naxes = (
482+
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
483+
)
374484
fig, axes = create_subplots(
375-
naxes=self.nseries,
485+
naxes=naxes,
376486
sharex=self.sharex,
377487
sharey=self.sharey,
378488
figsize=self.figsize,
@@ -784,9 +894,23 @@ def _get_ax_layer(cls, ax, primary=True):
784894
else:
785895
return getattr(ax, "right_ax", ax)
786896

897+
def _col_idx_to_axis_idx(self, col_idx: int) -> int:
898+
"""Return the index of the axis where the column at col_idx should be plotted"""
899+
if isinstance(self.subplots, list):
900+
# Subplots is a list: some columns will be grouped together in the same ax
901+
return next(
902+
group_idx
903+
for (group_idx, group) in enumerate(self.subplots)
904+
if col_idx in group
905+
)
906+
else:
907+
# subplots is True: one ax per column
908+
return col_idx
909+
787910
def _get_ax(self, i: int):
788911
# get the twinx ax if appropriate
789912
if self.subplots:
913+
i = self._col_idx_to_axis_idx(i)
790914
ax = self.axes[i]
791915
ax = self._maybe_right_yaxis(ax, i)
792916
self.axes[i] = ax

pandas/tests/plotting/frame/test_frame.py

+84
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,90 @@ def test_plot_no_numeric_data(self):
20712071
with pytest.raises(TypeError, match="no numeric data to plot"):
20722072
df.plot()
20732073

2074+
@td.skip_if_no_scipy
2075+
@pytest.mark.parametrize(
2076+
"kind", ("line", "bar", "barh", "hist", "kde", "density", "area", "pie")
2077+
)
2078+
def test_group_subplot(self, kind):
2079+
d = {
2080+
"a": np.arange(10),
2081+
"b": np.arange(10) + 1,
2082+
"c": np.arange(10) + 1,
2083+
"d": np.arange(10),
2084+
"e": np.arange(10),
2085+
}
2086+
df = DataFrame(d)
2087+
2088+
axes = df.plot(subplots=[("b", "e"), ("c", "d")], kind=kind)
2089+
assert len(axes) == 3 # 2 groups + single column a
2090+
2091+
expected_labels = (["b", "e"], ["c", "d"], ["a"])
2092+
for ax, labels in zip(axes, expected_labels):
2093+
if kind != "pie":
2094+
self._check_legend_labels(ax, labels=labels)
2095+
if kind == "line":
2096+
assert len(ax.lines) == len(labels)
2097+
2098+
def test_group_subplot_series_notimplemented(self):
2099+
ser = Series(range(1))
2100+
msg = "An iterable subplots for a Series"
2101+
with pytest.raises(NotImplementedError, match=msg):
2102+
ser.plot(subplots=[("a",)])
2103+
2104+
def test_group_subplot_multiindex_notimplemented(self):
2105+
df = DataFrame(np.eye(2), columns=MultiIndex.from_tuples([(0, 1), (1, 2)]))
2106+
msg = "An iterable subplots for a DataFrame with a MultiIndex"
2107+
with pytest.raises(NotImplementedError, match=msg):
2108+
df.plot(subplots=[(0, 1)])
2109+
2110+
def test_group_subplot_nonunique_cols_notimplemented(self):
2111+
df = DataFrame(np.eye(2), columns=["a", "a"])
2112+
msg = "An iterable subplots for a DataFrame with non-unique"
2113+
with pytest.raises(NotImplementedError, match=msg):
2114+
df.plot(subplots=[("a",)])
2115+
2116+
@pytest.mark.parametrize(
2117+
"subplots, expected_msg",
2118+
[
2119+
(123, "subplots should be a bool or an iterable"),
2120+
("a", "each entry should be a list/tuple"), # iterable of non-iterable
2121+
((1,), "each entry should be a list/tuple"), # iterable of non-iterable
2122+
(("a",), "each entry should be a list/tuple"), # iterable of strings
2123+
],
2124+
)
2125+
def test_group_subplot_bad_input(self, subplots, expected_msg):
2126+
# Make sure error is raised when subplots is not a properly
2127+
# formatted iterable. Only iterables of iterables are permitted, and
2128+
# entries should not be strings.
2129+
d = {"a": np.arange(10), "b": np.arange(10)}
2130+
df = DataFrame(d)
2131+
2132+
with pytest.raises(ValueError, match=expected_msg):
2133+
df.plot(subplots=subplots)
2134+
2135+
def test_group_subplot_invalid_column_name(self):
2136+
d = {"a": np.arange(10), "b": np.arange(10)}
2137+
df = DataFrame(d)
2138+
2139+
with pytest.raises(ValueError, match=r"Column label\(s\) \['bad_name'\]"):
2140+
df.plot(subplots=[("a", "bad_name")])
2141+
2142+
def test_group_subplot_duplicated_column(self):
2143+
d = {"a": np.arange(10), "b": np.arange(10), "c": np.arange(10)}
2144+
df = DataFrame(d)
2145+
2146+
with pytest.raises(ValueError, match="should be in only one subplot"):
2147+
df.plot(subplots=[("a", "b"), ("a", "c")])
2148+
2149+
@pytest.mark.parametrize("kind", ("box", "scatter", "hexbin"))
2150+
def test_group_subplot_invalid_kind(self, kind):
2151+
d = {"a": np.arange(10), "b": np.arange(10)}
2152+
df = DataFrame(d)
2153+
with pytest.raises(
2154+
ValueError, match="When subplots is an iterable, kind must be one of"
2155+
):
2156+
df.plot(subplots=[("a", "b")], kind=kind)
2157+
20742158
@pytest.mark.parametrize(
20752159
"index_name, old_label, new_label",
20762160
[

0 commit comments

Comments
 (0)