diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index 26b25597ce1a6..4d643ffb734e4 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -1,6 +1,6 @@ # being a bit too dynamic from math import ceil -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Union import warnings import matplotlib.table @@ -15,10 +15,13 @@ from pandas.plotting._matplotlib import compat if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.axis import Axis + from matplotlib.lines import Line2D # noqa:F401 from matplotlib.table import Table -def format_date_labels(ax, rot): +def format_date_labels(ax: "Axes", rot): # mini version of autofmt_xdate for label in ax.get_xticklabels(): label.set_ha("right") @@ -278,7 +281,7 @@ def _subplots( return fig, axes -def _remove_labels_from_axis(axis): +def _remove_labels_from_axis(axis: "Axis"): for t in axis.get_majorticklabels(): t.set_visible(False) @@ -294,7 +297,15 @@ def _remove_labels_from_axis(axis): axis.get_label().set_visible(False) -def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey): +def _handle_shared_axes( + axarr: Iterable["Axes"], + nplots: int, + naxes: int, + nrows: int, + ncols: int, + sharex: bool, + sharey: bool, +): if nplots > 1: if compat._mpl_ge_3_2_0(): row_num = lambda x: x.get_subplotspec().rowspan.start @@ -340,7 +351,7 @@ def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey): _remove_labels_from_axis(ax.yaxis) -def _flatten(axes): +def _flatten(axes: Union["Axes", Sequence["Axes"]]) -> Sequence["Axes"]: if not is_list_like(axes): return np.array([axes]) elif isinstance(axes, (np.ndarray, ABCIndexClass)): @@ -348,7 +359,13 @@ def _flatten(axes): return np.array(axes) -def _set_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None): +def _set_ticks_props( + axes: Union["Axes", Sequence["Axes"]], + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, +): import matplotlib.pyplot as plt for ax in _flatten(axes): @@ -363,7 +380,7 @@ def _set_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=Non return axes -def _get_all_lines(ax): +def _get_all_lines(ax: "Axes") -> List["Line2D"]: lines = ax.get_lines() if hasattr(ax, "right_ax"): @@ -375,7 +392,7 @@ def _get_all_lines(ax): return lines -def _get_xlim(lines) -> Tuple[float, float]: +def _get_xlim(lines: Iterable["Line2D"]) -> Tuple[float, float]: left, right = np.inf, -np.inf for l in lines: x = l.get_xdata(orig=False)