|
3 | 3 | from typing import (
|
4 | 4 | TYPE_CHECKING,
|
5 | 5 | Hashable,
|
| 6 | + Iterable, |
| 7 | + Sequence, |
6 | 8 | )
|
7 | 9 | import warnings
|
8 | 10 |
|
@@ -102,7 +104,7 @@ def __init__(
|
102 | 104 | data,
|
103 | 105 | kind=None,
|
104 | 106 | by: IndexLabel | None = None,
|
105 |
| - subplots=False, |
| 107 | + subplots: bool | Sequence[Sequence[str]] = False, |
106 | 108 | sharex=None,
|
107 | 109 | sharey=False,
|
108 | 110 | use_index=True,
|
@@ -166,8 +168,7 @@ def __init__(
|
166 | 168 | self.kind = kind
|
167 | 169 |
|
168 | 170 | self.sort_columns = sort_columns
|
169 |
| - |
170 |
| - self.subplots = subplots |
| 171 | + self.subplots = self._validate_subplots_kwarg(subplots) |
171 | 172 |
|
172 | 173 | if sharex is None:
|
173 | 174 |
|
@@ -253,6 +254,112 @@ def __init__(
|
253 | 254 |
|
254 | 255 | self._validate_color_args()
|
255 | 256 |
|
| 257 | + def _validate_subplots_kwarg( |
| 258 | + self, subplots: bool | Sequence[Sequence[str]] |
| 259 | + ) -> bool | list[tuple[int, ...]]: |
| 260 | + """ |
| 261 | + Validate the subplots parameter |
| 262 | +
|
| 263 | + - check type and content |
| 264 | + - check for duplicate columns |
| 265 | + - check for invalid column names |
| 266 | + - convert column names into indices |
| 267 | + - add missing columns in a group of their own |
| 268 | + See comments in code below for more details. |
| 269 | +
|
| 270 | + Parameters |
| 271 | + ---------- |
| 272 | + subplots : subplots parameters as passed to PlotAccessor |
| 273 | +
|
| 274 | + Returns |
| 275 | + ------- |
| 276 | + validated subplots : a bool or a list of tuples of column indices. Columns |
| 277 | + in the same tuple will be grouped together in the resulting plot. |
| 278 | + """ |
| 279 | + |
| 280 | + if isinstance(subplots, bool): |
| 281 | + return subplots |
| 282 | + elif not isinstance(subplots, Iterable): |
| 283 | + raise ValueError("subplots should be a bool or an iterable") |
| 284 | + |
| 285 | + supported_kinds = ( |
| 286 | + "line", |
| 287 | + "bar", |
| 288 | + "barh", |
| 289 | + "hist", |
| 290 | + "kde", |
| 291 | + "density", |
| 292 | + "area", |
| 293 | + "pie", |
| 294 | + ) |
| 295 | + if self._kind not in supported_kinds: |
| 296 | + raise ValueError( |
| 297 | + "When subplots is an iterable, kind must be " |
| 298 | + f"one of {', '.join(supported_kinds)}. Got {self._kind}." |
| 299 | + ) |
| 300 | + |
| 301 | + if isinstance(self.data, ABCSeries): |
| 302 | + raise NotImplementedError( |
| 303 | + "An iterable subplots for a Series is not supported." |
| 304 | + ) |
| 305 | + |
| 306 | + columns = self.data.columns |
| 307 | + if isinstance(columns, ABCMultiIndex): |
| 308 | + raise NotImplementedError( |
| 309 | + "An iterable subplots for a DataFrame with a MultiIndex column " |
| 310 | + "is not supported." |
| 311 | + ) |
| 312 | + |
| 313 | + if columns.nunique() != len(columns): |
| 314 | + raise NotImplementedError( |
| 315 | + "An iterable subplots for a DataFrame with non-unique column " |
| 316 | + "labels is not supported." |
| 317 | + ) |
| 318 | + |
| 319 | + # subplots is a list of tuples where each tuple is a group of |
| 320 | + # columns to be grouped together (one ax per group). |
| 321 | + # we consolidate the subplots list such that: |
| 322 | + # - the tuples contain indices instead of column names |
| 323 | + # - the columns that aren't yet in the list are added in a group |
| 324 | + # of their own. |
| 325 | + # For example with columns from a to g, and |
| 326 | + # subplots = [(a, c), (b, f, e)], |
| 327 | + # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)] |
| 328 | + # This way, we can handle self.subplots in a homogeneous manner |
| 329 | + # later. |
| 330 | + # TODO: also accept indices instead of just names? |
| 331 | + |
| 332 | + out = [] |
| 333 | + seen_columns: set[Hashable] = set() |
| 334 | + for group in subplots: |
| 335 | + if not is_list_like(group): |
| 336 | + raise ValueError( |
| 337 | + "When subplots is an iterable, each entry " |
| 338 | + "should be a list/tuple of column names." |
| 339 | + ) |
| 340 | + idx_locs = columns.get_indexer_for(group) |
| 341 | + if (idx_locs == -1).any(): |
| 342 | + bad_labels = np.extract(idx_locs == -1, group) |
| 343 | + raise ValueError( |
| 344 | + f"Column label(s) {list(bad_labels)} not found in the DataFrame." |
| 345 | + ) |
| 346 | + else: |
| 347 | + unique_columns = set(group) |
| 348 | + duplicates = seen_columns.intersection(unique_columns) |
| 349 | + if duplicates: |
| 350 | + raise ValueError( |
| 351 | + "Each column should be in only one subplot. " |
| 352 | + f"Columns {duplicates} were found in multiple subplots." |
| 353 | + ) |
| 354 | + seen_columns = seen_columns.union(unique_columns) |
| 355 | + out.append(tuple(idx_locs)) |
| 356 | + |
| 357 | + unseen_columns = columns.difference(seen_columns) |
| 358 | + for column in unseen_columns: |
| 359 | + idx_loc = columns.get_loc(column) |
| 360 | + out.append((idx_loc,)) |
| 361 | + return out |
| 362 | + |
256 | 363 | def _validate_color_args(self):
|
257 | 364 | if (
|
258 | 365 | "color" in self.kwds
|
@@ -371,8 +478,11 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
|
371 | 478 |
|
372 | 479 | def _setup_subplots(self):
|
373 | 480 | if self.subplots:
|
| 481 | + naxes = ( |
| 482 | + self.nseries if isinstance(self.subplots, bool) else len(self.subplots) |
| 483 | + ) |
374 | 484 | fig, axes = create_subplots(
|
375 |
| - naxes=self.nseries, |
| 485 | + naxes=naxes, |
376 | 486 | sharex=self.sharex,
|
377 | 487 | sharey=self.sharey,
|
378 | 488 | figsize=self.figsize,
|
@@ -784,9 +894,23 @@ def _get_ax_layer(cls, ax, primary=True):
|
784 | 894 | else:
|
785 | 895 | return getattr(ax, "right_ax", ax)
|
786 | 896 |
|
def _col_idx_to_axis_idx(self, col_idx: int) -> int:
    """
    Map a column position to the position of the axes it is drawn on.

    Parameters
    ----------
    col_idx : int
        Position of the column.

    Returns
    -------
    int
        ``col_idx`` itself when ``self.subplots`` is not a list (one ax per
        column); otherwise the position of the first group in
        ``self.subplots`` that contains ``col_idx``.
    """
    groups = self.subplots
    if not isinstance(groups, list):
        # subplots is True: one ax per column
        return col_idx

    # subplots is a list: several columns may share the same ax, so look up
    # the group holding this column.
    matching_groups = (
        position for position, members in enumerate(groups) if col_idx in members
    )
    return next(matching_groups)
| 909 | + |
787 | 910 | def _get_ax(self, i: int):
|
788 | 911 | # get the twinx ax if appropriate
|
789 | 912 | if self.subplots:
|
| 913 | + i = self._col_idx_to_axis_idx(i) |
790 | 914 | ax = self.axes[i]
|
791 | 915 | ax = self._maybe_right_yaxis(ax, i)
|
792 | 916 | self.axes[i] = ax
|
|
0 commit comments