From 6b18aaa7739b9b7210f314d42261d57739e51578 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 4 Nov 2023 14:57:03 -0700 Subject: [PATCH 1/2] TYP: plotting --- pandas/plotting/_matplotlib/core.py | 45 +++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index d88605db60720..b95bb3b2a6853 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -12,6 +12,7 @@ from typing import ( TYPE_CHECKING, Literal, + final, ) import warnings @@ -272,6 +273,7 @@ def __init__( self._validate_color_args() + @final def _validate_subplots_kwarg( self, subplots: bool | Sequence[Sequence[str]] ) -> bool | list[tuple[int, ...]]: @@ -418,6 +420,7 @@ def _validate_color_args(self): "other or pass 'style' without a color symbol" ) + @final def _iter_data(self, data=None, keep_index: bool = False, fillna=None): if data is None: data = self.data @@ -443,9 +446,11 @@ def nseries(self) -> int: else: return self.data.shape[1] + @final def draw(self) -> None: self.plt.draw_if_interactive() + @final def generate(self) -> None: self._args_adjust() self._compute_plot_data() @@ -463,11 +468,13 @@ def generate(self) -> None: def _args_adjust(self) -> None: pass + @final def _has_plotted_object(self, ax: Axes) -> bool: """check whether ax has data""" return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0 - def _maybe_right_yaxis(self, ax: Axes, axes_num): + @final + def _maybe_right_yaxis(self, ax: Axes, axes_num: int): if not self.on_right(axes_num): # secondary axes may be passed via ax kw return self._get_ax_layer(ax) @@ -495,6 +502,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num): new_ax.set_yscale("symlog") return new_ax + @final def _setup_subplots(self): if self.subplots: naxes = ( @@ -565,6 +573,7 @@ def result(self): else: return self.axes[0] + @final def _convert_to_ndarray(self, data): # GH31357: categorical columns are processed separately if isinstance(data.dtype, CategoricalDtype): @@ -583,6 +592,7 @@ def _convert_to_ndarray(self, data): return data + @final def _compute_plot_data(self): data = self.data @@ -640,6 +650,7 @@ def _compute_plot_data(self): def _make_plot(self): raise AbstractMethodError(self) + @final def _add_table(self) -> None: if self.table is False: return @@ -650,6 +661,7 @@ def _add_table(self) -> None: ax = self._get_ax(0) tools.table(ax, data) + @final def _post_plot_logic_common(self, ax, data): """Common post process for each axes""" if self.orientation == "vertical" or self.orientation is None: @@ -672,6 +684,7 @@ def _post_plot_logic_common(self, ax, data): def _post_plot_logic(self, ax, data) -> None: """Post process for each axes. Overridden in child classes""" + @final def _adorn_subplots(self): """Common post process unrelated to data""" if len(self.axes) > 0: @@ -733,6 +746,7 @@ def _adorn_subplots(self): raise ValueError(msg) self.axes[0].set_title(self.title) + @final def _apply_axis_properties( self, axis: Axis, rot=None, fontsize: int | None = None ) -> None: @@ -762,6 +776,7 @@ def legend_title(self) -> str | None: stringified = map(pprint_thing, self.data.columns.names) return ",".join(stringified) + @final def _mark_right_label(self, label: str, index: int) -> str: """ Append ``(right)`` to the label of a line if it's plotted on the right axis. @@ -772,6 +787,7 @@ def _mark_right_label(self, label: str, index: int) -> str: label += " (right)" return label + @final def _append_legend_handles_labels(self, handle: Artist, label: str) -> None: """ Append current handle and label to ``legend_handles`` and ``legend_labels``. @@ -817,6 +833,7 @@ def _make_legend(self) -> None: if ax.get_visible(): ax.legend(loc="best") + @final def _get_ax_legend(self, ax: Axes): """ Take in axes and return ax and legend under different scenarios @@ -832,6 +849,7 @@ def _get_ax_legend(self, ax: Axes): ax = other_ax return ax, leg + @final @cache_readonly def plt(self): import matplotlib.pyplot as plt @@ -840,6 +858,7 @@ def plt(self): _need_to_set_index = False + @final def _get_xticks(self, convert_period: bool = False): index = self.data.index is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time") @@ -890,10 +909,12 @@ def _plot( args = (x, y, style) if style is not None else (x, y) return ax.plot(*args, **kwds) + @final def _get_custom_index_name(self): """Specify whether xlabel/ylabel should be used to override index name""" return self.xlabel + @final def _get_index_name(self) -> str | None: if isinstance(self.data.index, ABCMultiIndex): name = self.data.index.names @@ -913,6 +934,7 @@ def _get_index_name(self) -> str | None: return name + @final @classmethod def _get_ax_layer(cls, ax, primary: bool = True): """get left (primary) or right (secondary) axes""" @@ -921,6 +943,7 @@ def _get_ax_layer(cls, ax, primary: bool = True): else: return getattr(ax, "right_ax", ax) + @final def _col_idx_to_axis_idx(self, col_idx: int) -> int: """Return the index of the axis where the column at col_idx should be plotted""" if isinstance(self.subplots, list): @@ -934,6 +957,7 @@ def _col_idx_to_axis_idx(self, col_idx: int) -> int: # subplots is True: one ax per column return col_idx + @final def _get_ax(self, i: int): # get the twinx ax if appropriate if self.subplots: @@ -948,6 +972,7 @@ def _get_ax(self, i: int): ax.get_yaxis().set_visible(True) return ax + @final @classmethod def get_default_ax(cls, ax) -> None: import matplotlib.pyplot as plt @@ -957,13 +982,15 @@ def get_default_ax(cls, ax) -> None: ax = plt.gca() ax = cls._get_ax_layer(ax) - def on_right(self, i): + @final + def on_right(self, i: int): if isinstance(self.secondary_y, bool): return self.secondary_y if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)): return self.data.columns[i] in self.secondary_y + @final def _apply_style_colors(self, colors, kwds, col_num, label: str): """ Manage style and color based on column number and its label. @@ -1004,6 +1031,7 @@ def _get_colors( color=self.kwds.get(color_kwds), ) + @final def _parse_errorbars(self, label, err): """ Look for error keyword arguments and return the actual errorbar data @@ -1093,6 +1121,7 @@ def match_labels(data, e): return err + @final def _get_errorbars( self, label=None, index=None, xerr: bool = True, yerr: bool = True ): @@ -1114,6 +1143,7 @@ def _get_errorbars( errors[kw] = err return errors + @final def _get_subplots(self): from matplotlib.axes import Subplot @@ -1123,6 +1153,7 @@ def _get_subplots(self): if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None) ] + @final def _get_axes_layout(self) -> tuple[int, int]: axes = self._get_subplots() x_set = set() @@ -1161,10 +1192,12 @@ def __init__(self, data, x, y, **kwargs) -> None: self.x = x self.y = y + @final @property def nseries(self) -> int: return 1 + @final def _post_plot_logic(self, ax: Axes, data) -> None: x, y = self.x, self.y xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x) @@ -1172,6 +1205,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None: ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) + @final def _plot_colorbar(self, ax: Axes, **kwds): # Addresses issues #10611 and #10678: # When plotting scatterplots and hexbinplots in IPython @@ -1351,10 +1385,12 @@ def __init__(self, data, **kwargs) -> None: if "x_compat" in self.kwds: self.x_compat = bool(self.kwds.pop("x_compat")) + @final def _is_ts_plot(self) -> bool: # this is slightly deceptive return not self.x_compat and self.use_index and self._use_dynamic_x() + @final def _use_dynamic_x(self): return use_dynamic_x(self._get_ax(0), self.data) @@ -1422,6 +1458,7 @@ def _plot( # type: ignore[override] cls._update_stacker(ax, stacking_id, y) return lines + @final def _ts_plot(self, ax: Axes, x, data, style=None, **kwds): # accept x to be consistent with normal plot func, # x is not passed to tsplot as it uses data.index as x coordinate @@ -1442,12 +1479,14 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds): format_dateaxis(ax, ax.freq, data.index) return lines + @final def _get_stacking_id(self): if self.stacked: return id(self.data) else: return None + @final @classmethod def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None: if stacking_id is None: @@ -1459,6 +1498,7 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None: ax._stacker_pos_prior[stacking_id] = np.zeros(n) ax._stacker_neg_prior[stacking_id] = np.zeros(n) + @final @classmethod def _get_stacked_values(cls, ax: Axes, stacking_id, values, label): if stacking_id is None: @@ -1478,6 +1518,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label): f"Column '{label}' contains both positive and negative values" ) + @final @classmethod def _update_stacker(cls, ax: Axes, stacking_id, values) -> None: if stacking_id is None: From a8795dcec8c59454fb5dd456c5f1131574f8d2fd Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 4 Nov 2023 15:30:41 -0700 Subject: [PATCH 2/2] mypy fixup --- pandas/plotting/_matplotlib/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index b95bb3b2a6853..c02330ae4e452 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -909,7 +909,6 @@ def _plot( args = (x, y, style) if style is not None else (x, y) return ax.plot(*args, **kwds) - @final def _get_custom_index_name(self): """Specify whether xlabel/ylabel should be used to override index name""" return self.xlabel