diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst
index 45c32d689bd5b..fccfbb98f2591 100644
--- a/doc/source/whatsnew/v1.5.0.rst
+++ b/doc/source/whatsnew/v1.5.0.rst
@@ -133,6 +133,7 @@ Other enhancements
 - :meth:`MultiIndex.to_frame` now supports the argument ``allow_duplicates`` and raises on duplicate labels if it is missing or False (:issue:`45245`)
 - :class:`StringArray` now accepts array-likes containing nan-likes (``None``, ``np.nan``) for the ``values`` parameter in its constructor in addition to strings and :attr:`pandas.NA`. (:issue:`40839`)
 - Improved the rendering of ``categories`` in :class:`CategoricalIndex` (:issue:`45218`)
+- :meth:`DataFrame.plot` will now allow the ``subplots`` parameter to be a list of iterables specifying column groups, so that columns may be grouped together in the same subplot (:issue:`29688`).
 - :meth:`to_numeric` now preserves float64 arrays when downcasting would generate values not representable in float32 (:issue:`43693`)
 - :meth:`Series.reset_index` and :meth:`DataFrame.reset_index` now support the argument ``allow_duplicates`` (:issue:`44410`)
 - :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now supports `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)
diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py
index 29110f63787a6..929ddb52aea6d 100644
--- a/pandas/plotting/_core.py
+++ b/pandas/plotting/_core.py
@@ -649,8 +649,19 @@ class PlotAccessor(PandasObject):
         - 'hexbin' : hexbin plot (DataFrame only)
     ax : matplotlib axes object, default None
         An axes of the current figure.
-    subplots : bool, default False
-        Make separate subplots for each column.
+    subplots : bool or sequence of iterables, default False
+        Whether to group columns into subplots:
+
+        - ``False`` : No subplots will be used
+        - ``True`` : Make separate subplots for each column.
+        - sequence of iterables of column labels: Create a subplot for each
+          group of columns. For example ``[('a', 'c'), ('b', 'd')]`` will
+          create 2 subplots: one with columns 'a' and 'c', and one
+          with columns 'b' and 'd'. Remaining columns that aren't specified
+          will be plotted in additional subplots (one per column).
+
+          .. versionadded:: 1.5.0
+
     sharex : bool, default True if ax is None else False
         In case ``subplots=True``, share x axis and set some x axis labels
         to invisible; defaults to True if ax is None otherwise False if
diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py
index 3899885bd95f1..5fceb14b9d1cc 100644
--- a/pandas/plotting/_matplotlib/core.py
+++ b/pandas/plotting/_matplotlib/core.py
@@ -3,6 +3,8 @@
 from typing import (
     TYPE_CHECKING,
     Hashable,
+    Iterable,
+    Sequence,
 )
 import warnings
 
@@ -102,7 +104,7 @@ def __init__(
         data,
         kind=None,
         by: IndexLabel | None = None,
-        subplots=False,
+        subplots: bool | Sequence[Sequence[str]] = False,
         sharex=None,
         sharey=False,
         use_index=True,
@@ -166,8 +168,7 @@ def __init__(
         self.kind = kind
 
         self.sort_columns = sort_columns
-
-        self.subplots = subplots
+        self.subplots = self._validate_subplots_kwarg(subplots)
 
         if sharex is None:
 
@@ -253,6 +254,125 @@ def __init__(
 
         self._validate_color_args()
 
+    def _validate_subplots_kwarg(
+        self, subplots: bool | Sequence[Sequence[str]]
+    ) -> bool | list[tuple[int, ...]]:
+        """
+        Validate the subplots parameter
+
+        - check type and content
+        - check for duplicate columns
+        - check for invalid column names
+        - convert column names into indices
+        - add missing columns in a group of their own
+        See comments in code below for more details.
+
+        Parameters
+        ----------
+        subplots : subplots parameters as passed to PlotAccessor
+
+        Returns
+        -------
+        validated subplots : a bool or a list of tuples of column indices. Columns
+            in the same tuple will be grouped together in the resulting plot.
+
+        Raises
+        ------
+        ValueError
+            If subplots is neither a bool nor an iterable of list-likes, if the
+            plot kind does not support grouped subplots, if a label is not a
+            column of the data, or if a column appears in more than one group.
+        NotImplementedError
+            If the data is a Series, or a DataFrame whose columns are a
+            MultiIndex or are not unique.
+        """
+
+        if isinstance(subplots, bool):
+            return subplots
+        elif not isinstance(subplots, Iterable):
+            raise ValueError("subplots should be a bool or an iterable")
+
+        supported_kinds = (
+            "line",
+            "bar",
+            "barh",
+            "hist",
+            "kde",
+            "density",
+            "area",
+            "pie",
+        )
+        if self._kind not in supported_kinds:
+            raise ValueError(
+                "When subplots is an iterable, kind must be "
+                f"one of {', '.join(supported_kinds)}. Got {self._kind}."
+            )
+
+        if isinstance(self.data, ABCSeries):
+            raise NotImplementedError(
+                "An iterable subplots for a Series is not supported."
+            )
+
+        columns = self.data.columns
+        if isinstance(columns, ABCMultiIndex):
+            raise NotImplementedError(
+                "An iterable subplots for a DataFrame with a MultiIndex column "
+                "is not supported."
+            )
+
+        if columns.nunique() != len(columns):
+            raise NotImplementedError(
+                "An iterable subplots for a DataFrame with non-unique column "
+                "labels is not supported."
+            )
+
+        # subplots is a list of tuples where each tuple is a group of
+        # columns to be grouped together (one ax per group).
+        # we consolidate the subplots list such that:
+        # - the tuples contain indices instead of column names
+        # - the columns that aren't yet in the list are added in a group
+        #   of their own.
+        # For example with columns from a to g, and
+        # subplots = [(a, c), (b, f, e)],
+        # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)]
+        # This way, we can handle self.subplots in a homogeneous manner
+        # later.
+        # TODO: also accept indices instead of just names?
+
+        out = []
+        seen_columns: set[Hashable] = set()
+        for group in subplots:
+            if not is_list_like(group):
+                raise ValueError(
+                    "When subplots is an iterable, each entry "
+                    "should be a list/tuple of column names."
+                )
+            # Materialize the group: it may be a one-shot iterator, and it is
+            # traversed several times below.
+            group = list(group)
+            idx_locs = columns.get_indexer_for(group)
+            if (idx_locs == -1).any():
+                bad_labels = np.extract(idx_locs == -1, group)
+                raise ValueError(
+                    f"Column label(s) {list(bad_labels)} not found in the DataFrame."
+                )
+            else:
+                unique_columns = set(group)
+                duplicates = seen_columns.intersection(unique_columns)
+                if duplicates:
+                    raise ValueError(
+                        "Each column should be in only one subplot. "
+                        f"Columns {duplicates} were found in multiple subplots."
+                    )
+                seen_columns = seen_columns.union(unique_columns)
+                out.append(tuple(idx_locs))
+
+        unseen_columns = columns.difference(seen_columns)
+        for column in unseen_columns:
+            idx_loc = columns.get_loc(column)
+            out.append((idx_loc,))
+        return out
+
     def _validate_color_args(self):
         if (
             "color" in self.kwds
@@ -371,8 +491,11 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
 
     def _setup_subplots(self):
         if self.subplots:
+            naxes = (
+                self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
+            )
             fig, axes = create_subplots(
-                naxes=self.nseries,
+                naxes=naxes,
                 sharex=self.sharex,
                 sharey=self.sharey,
                 figsize=self.figsize,
@@ -784,9 +907,23 @@ def _get_ax_layer(cls, ax, primary=True):
         else:
             return getattr(ax, "right_ax", ax)
 
+    def _col_idx_to_axis_idx(self, col_idx: int) -> int:
+        """Return the index of the axis where the column at col_idx should be plotted"""
+        if isinstance(self.subplots, list):
+            # Subplots is a list: some columns will be grouped together in the same ax
+            return next(
+                group_idx
+                for (group_idx, group) in enumerate(self.subplots)
+                if col_idx in group
+            )
+        else:
+            # subplots is True: one ax per column
+            return col_idx
+
     def _get_ax(self, i: int):
         # get the twinx ax if appropriate
         if self.subplots:
+            i = self._col_idx_to_axis_idx(i)
             ax = self.axes[i]
             ax = self._maybe_right_yaxis(ax, i)
             self.axes[i] = ax
diff --git a/pandas/tests/plotting/frame/test_frame.py b/pandas/tests/plotting/frame/test_frame.py
index c4ce0b256cd41..3ec3744e43653 100644
--- a/pandas/tests/plotting/frame/test_frame.py
+++ b/pandas/tests/plotting/frame/test_frame.py
@@ -2071,6 +2071,98 @@ def test_plot_no_numeric_data(self):
         with pytest.raises(TypeError, match="no numeric data to plot"):
             df.plot()
 
+    @td.skip_if_no_scipy
+    @pytest.mark.parametrize(
+        "kind", ("line", "bar", "barh", "hist", "kde", "density", "area", "pie")
+    )
+    def test_group_subplot(self, kind):
+        d = {
+            "a": np.arange(10),
+            "b": np.arange(10) + 1,
+            "c": np.arange(10) + 1,
+            "d": np.arange(10),
+            "e": np.arange(10),
+        }
+        df = DataFrame(d)
+
+        axes = df.plot(subplots=[("b", "e"), ("c", "d")], kind=kind)
+        assert len(axes) == 3  # 2 groups + single column a
+
+        expected_labels = (["b", "e"], ["c", "d"], ["a"])
+        for ax, labels in zip(axes, expected_labels):
+            if kind != "pie":
+                self._check_legend_labels(ax, labels=labels)
+            if kind == "line":
+                assert len(ax.lines) == len(labels)
+
+    def test_group_subplot_one_shot_iterable(self):
+        # A group given as a one-shot iterator must be handled like a list/tuple
+        d = {"a": np.arange(10), "b": np.arange(10), "c": np.arange(10)}
+        df = DataFrame(d)
+
+        axes = df.plot(subplots=[iter(("a", "b"))])
+        assert len(axes) == 2  # 1 group + single column c
+
+    def test_group_subplot_series_notimplemented(self):
+        ser = Series(range(1))
+        msg = "An iterable subplots for a Series"
+        with pytest.raises(NotImplementedError, match=msg):
+            ser.plot(subplots=[("a",)])
+
+    def test_group_subplot_multiindex_notimplemented(self):
+        df = DataFrame(np.eye(2), columns=MultiIndex.from_tuples([(0, 1), (1, 2)]))
+        msg = "An iterable subplots for a DataFrame with a MultiIndex"
+        with pytest.raises(NotImplementedError, match=msg):
+            df.plot(subplots=[(0, 1)])
+
+    def test_group_subplot_nonunique_cols_notimplemented(self):
+        df = DataFrame(np.eye(2), columns=["a", "a"])
+        msg = "An iterable subplots for a DataFrame with non-unique"
+        with pytest.raises(NotImplementedError, match=msg):
+            df.plot(subplots=[("a",)])
+
+    @pytest.mark.parametrize(
+        "subplots, expected_msg",
+        [
+            (123, "subplots should be a bool or an iterable"),
+            ("a", "each entry should be a list/tuple"),  # iterable of non-iterable
+            ((1,), "each entry should be a list/tuple"),  # iterable of non-iterable
+            (("a",), "each entry should be a list/tuple"),  # iterable of strings
+        ],
+    )
+    def test_group_subplot_bad_input(self, subplots, expected_msg):
+        # Make sure error is raised when subplots is not a properly
+        # formatted iterable. Only iterables of iterables are permitted, and
+        # entries should not be strings.
+        d = {"a": np.arange(10), "b": np.arange(10)}
+        df = DataFrame(d)
+
+        with pytest.raises(ValueError, match=expected_msg):
+            df.plot(subplots=subplots)
+
+    def test_group_subplot_invalid_column_name(self):
+        d = {"a": np.arange(10), "b": np.arange(10)}
+        df = DataFrame(d)
+
+        with pytest.raises(ValueError, match=r"Column label\(s\) \['bad_name'\]"):
+            df.plot(subplots=[("a", "bad_name")])
+
+    def test_group_subplot_duplicated_column(self):
+        d = {"a": np.arange(10), "b": np.arange(10), "c": np.arange(10)}
+        df = DataFrame(d)
+
+        with pytest.raises(ValueError, match="should be in only one subplot"):
+            df.plot(subplots=[("a", "b"), ("a", "c")])
+
+    @pytest.mark.parametrize("kind", ("box", "scatter", "hexbin"))
+    def test_group_subplot_invalid_kind(self, kind):
+        d = {"a": np.arange(10), "b": np.arange(10)}
+        df = DataFrame(d)
+        with pytest.raises(
+            ValueError, match="When subplots is an iterable, kind must be one of"
+        ):
+            df.plot(subplots=[("a", "b")], kind=kind)
+
     @pytest.mark.parametrize(
         "index_name, old_label, new_label",
         [