diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index 3b60e95b9dceb..a017787f2dc2d 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -37,11 +37,15 @@ from pandas._typing import IndexLabel - from pandas import DataFrame + from pandas import ( + DataFrame, + Series, + ) + from pandas.core.groupby.generic import DataFrameGroupBy def hist_series( - self, + self: Series, by=None, ax=None, grid: bool = True, @@ -512,7 +516,7 @@ def boxplot( @Substitution(data="", backend=_backend_doc) @Appender(_boxplot_doc) def boxplot_frame( - self, + self: DataFrame, column=None, by=None, ax=None, @@ -542,7 +546,7 @@ def boxplot_frame( def boxplot_frame_groupby( - grouped, + grouped: DataFrameGroupBy, subplots: bool = True, column=None, fontsize: int | None = None, @@ -843,11 +847,11 @@ class PlotAccessor(PandasObject): _kind_aliases = {"density": "kde"} _all_kinds = _common_kinds + _series_kinds + _dataframe_kinds - def __init__(self, data) -> None: + def __init__(self, data: Series | DataFrame) -> None: self._parent = data @staticmethod - def _get_call_args(backend_name: str, data, args, kwargs): + def _get_call_args(backend_name: str, data: Series | DataFrame, args, kwargs): """ This function makes calls to this accessor `__call__` method compatible with the previous `SeriesPlotMethods.__call__` and diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index e6481aab50f6e..37ebd940f3646 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -93,17 +93,18 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None: # error: Signature of "_plot" incompatible with supertype "MPLPlot" @classmethod def _plot( # type: ignore[override] - cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds + cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds ): + ys: np.ndarray | list[np.ndarray] if y.ndim == 2: - y = [remove_na_arraylike(v) for v in y] + ys = [remove_na_arraylike(v) for v in y] # Boxplot fails with empty arrays, so need to add a NaN # if any cols are empty # GH 8181 - y = [v if v.size > 0 else np.array([np.nan]) for v in y] + ys = [v if v.size > 0 else np.array([np.nan]) for v in ys] else: - y = remove_na_arraylike(y) - bp = ax.boxplot(y, **kwds) + ys = remove_na_arraylike(y) + bp = ax.boxplot(ys, **kwds) if return_type == "dict": return bp, bp @@ -240,8 +241,7 @@ def _make_plot(self, fig: Figure) -> None: self.maybe_color_bp(bp) self._return_obj = ret - labels = [left for left, _ in self._iter_data()] - labels = [pprint_thing(left) for left in labels] + labels = [pprint_thing(left) for left in self.data.columns] if not self.use_index: labels = [pprint_thing(key) for key in range(len(labels))] _set_ticklabels( @@ -251,7 +251,7 @@ def _make_plot(self, fig: Figure) -> None: def _make_legend(self) -> None: pass - def _post_plot_logic(self, ax, data) -> None: + def _post_plot_logic(self, ax: Axes, data) -> None: # GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel if self.xlabel: ax.set_xlabel(pprint_thing(self.xlabel)) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index b67a8186c8c2b..c2b0f43d83e60 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -7,6 +7,7 @@ from collections.abc import ( Hashable, Iterable, + Iterator, Sequence, ) from typing import ( @@ -431,17 +432,15 @@ def _validate_color_args(self): ) @final - def _iter_data(self, data=None, keep_index: bool = False, fillna=None): - if data is None: - data = self.data - if fillna is not None: - data = data.fillna(fillna) - + @staticmethod + def _iter_data( + data: DataFrame | dict[Hashable, Series | DataFrame] + ) -> Iterator[tuple[Hashable, np.ndarray]]: for col, values in data.items(): - if keep_index is True: - yield col, values - else: - yield col, values.values + # This was originally written to use values.values before EAs + # were implemented; adding np.asarray(...) to keep consistent + # typing. + yield col, np.asarray(values.values) @property def nseries(self) -> int: @@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool: return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0 @final - def _maybe_right_yaxis(self, ax: Axes, axes_num: int): + def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes: if not self.on_right(axes_num): # secondary axes may be passed via ax kw return self._get_ax_layer(ax) @@ -643,11 +642,7 @@ def _compute_plot_data(self): numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type) - try: - is_empty = numeric_data.columns.empty - except AttributeError: - is_empty = not len(numeric_data) - + is_empty = numeric_data.shape[-1] == 0 # no non-numeric frames or series allowed if is_empty: raise TypeError("no numeric data to plot") @@ -669,7 +664,7 @@ def _add_table(self) -> None: tools.table(ax, data) @final - def _post_plot_logic_common(self, ax, data): + def _post_plot_logic_common(self, ax: Axes, data) -> None: """Common post process for each axes""" if self.orientation == "vertical" or self.orientation is None: self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize) @@ -688,7 +683,7 @@ def _post_plot_logic_common(self, ax, data): raise ValueError @abstractmethod - def _post_plot_logic(self, ax, data) -> None: + def _post_plot_logic(self, ax: Axes, data) -> None: """Post process for each axes. Overridden in child classes""" @final @@ -1042,7 +1037,7 @@ def _get_colors( ) @final - def _parse_errorbars(self, label, err): + def _parse_errorbars(self, label: str, err): """ Look for error keyword arguments and return the actual errorbar data or return the error DataFrame/dict @@ -1123,7 +1118,10 @@ def match_labels(data, e): err = np.tile(err, (self.nseries, 1)) elif is_number(err): - err = np.tile([err], (self.nseries, len(self.data))) + err = np.tile( + [err], # pyright: ignore[reportGeneralTypeIssues] + (self.nseries, len(self.data)), + ) else: msg = f"No valid {label} detected" @@ -1404,14 +1402,14 @@ def _make_plot(self, fig: Figure) -> None: x = data.index # dummy, not used plotf = self._ts_plot - it = self._iter_data(data=data, keep_index=True) + it = data.items() else: x = self._get_xticks(convert_period=True) # error: Incompatible types in assignment (expression has type # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]") plotf = self._plot # type: ignore[assignment] - it = self._iter_data() + it = self._iter_data(data=self.data) stacking_id = self._get_stacking_id() is_errorbar = com.any_not_none(*self.errors.values()) @@ -1420,7 +1418,12 @@ def _make_plot(self, fig: Figure) -> None: for i, (label, y) in enumerate(it): ax = self._get_ax(i) kwds = self.kwds.copy() - style, kwds = self._apply_style_colors(colors, kwds, i, label) + style, kwds = self._apply_style_colors( + colors, + kwds, + i, + label, # pyright: ignore[reportGeneralTypeIssues] + ) errors = self._get_errorbars(label=label, index=i) kwds = dict(kwds, **errors) @@ -1432,7 +1435,7 @@ def _make_plot(self, fig: Figure) -> None: newlines = plotf( ax, x, - y, + y, # pyright: ignore[reportGeneralTypeIssues] style=style, column_num=i, stacking_id=stacking_id, @@ -1451,7 +1454,14 @@ def _make_plot(self, fig: Figure) -> None: # error: Signature of "_plot" incompatible with supertype "MPLPlot" @classmethod def _plot( # type: ignore[override] - cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds + cls, + ax: Axes, + x, + y: np.ndarray, + style=None, + column_num=None, + stacking_id=None, + **kwds, ): # column_num is used to get the target column from plotf in line and # area plots @@ -1478,7 +1488,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds): decorate_axes(ax.right_ax, freq, kwds) ax._plot_data.append((data, self._kind, kwds)) - lines = self._plot(ax, data.index, data.values, style=style, **kwds) + lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds) # set date formatter, locators and rescale limits # error: Argument 3 to "format_dateaxis" has incompatible type "Index"; # expected "DatetimeIndex | PeriodIndex" @@ -1506,7 +1516,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None: @final @classmethod - def _get_stacked_values(cls, ax: Axes, stacking_id, values, label): + def _get_stacked_values( + cls, ax: Axes, stacking_id: int | None, values: np.ndarray, label + ) -> np.ndarray: if stacking_id is None: return values if not hasattr(ax, "_stacker_pos_prior"): @@ -1526,7 +1538,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label): @final @classmethod - def _update_stacker(cls, ax: Axes, stacking_id, values) -> None: + def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None: if stacking_id is None: return if (values >= 0).all(): @@ -1604,7 +1616,7 @@ def _plot( # type: ignore[override] cls, ax: Axes, x, - y, + y: np.ndarray, style=None, column_num=None, stacking_id=None, @@ -1730,7 +1742,7 @@ def _plot( # type: ignore[override] cls, ax: Axes, x, - y, + y: np.ndarray, w, start: int | npt.NDArray[np.intp] = 0, log: bool = False, @@ -1749,7 +1761,8 @@ def _make_plot(self, fig: Figure) -> None: pos_prior = neg_prior = np.zeros(len(self.data)) K = self.nseries - for i, (label, y) in enumerate(self._iter_data(fillna=0)): + data = self.data.fillna(0) + for i, (label, y) in enumerate(self._iter_data(data=data)): ax = self._get_ax(i) kwds = self.kwds.copy() if self._is_series: @@ -1828,7 +1841,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None: self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge) - def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None: + def _decorate_ticks( + self, + ax: Axes, + name: str | None, + ticklabels: list[str], + start_edge: float, + end_edge: float, + ) -> None: ax.set_xlim((start_edge, end_edge)) if self.xticks is not None: @@ -1862,7 +1882,7 @@ def _plot( # type: ignore[override] cls, ax: Axes, x, - y, + y: np.ndarray, w, start: int | npt.NDArray[np.intp] = 0, log: bool = False, @@ -1873,7 +1893,14 @@ def _plot( # type: ignore[override] def _get_custom_index_name(self): return self.ylabel - def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None: + def _decorate_ticks( + self, + ax: Axes, + name: str | None, + ticklabels: list[str], + start_edge: float, + end_edge: float, + ) -> None: # horizontal bars ax.set_ylim((start_edge, end_edge)) ax.set_yticks(self.tick_pos) @@ -1907,7 +1934,7 @@ def _make_plot(self, fig: Figure) -> None: colors = self._get_colors(num_colors=len(self.data), color_kwds="colors") self.kwds.setdefault("colors", colors) - for i, (label, y) in enumerate(self._iter_data()): + for i, (label, y) in enumerate(self._iter_data(data=self.data)): ax = self._get_ax(i) if label is not None: label = pprint_thing(label) diff --git a/pandas/plotting/_matplotlib/groupby.py b/pandas/plotting/_matplotlib/groupby.py index 00c3c4114d377..d32741a1d1803 100644 --- a/pandas/plotting/_matplotlib/groupby.py +++ b/pandas/plotting/_matplotlib/groupby.py @@ -16,12 +16,14 @@ from pandas.plotting._matplotlib.misc import unpack_single_str_list if TYPE_CHECKING: + from collections.abc import Hashable + from pandas._typing import IndexLabel def create_iter_data_given_by( data: DataFrame, kind: str = "hist" -) -> dict[str, DataFrame | Series]: +) -> dict[Hashable, DataFrame | Series]: """ Create data for iteration given `by` is assigned or not, and it is only used in both hist and boxplot. @@ -126,9 +128,7 @@ def reconstruct_data_with_by( return data -def reformat_hist_y_given_by( - y: Series | np.ndarray, by: IndexLabel | None -) -> Series | np.ndarray: +def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray: """Internal function to reformat y given `by` is applied or not for hist plot. If by is None, input y is 1-d with NaN removed; and if by is not None, groupby diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index fd0dde40c0ab3..ed93dc740e1b4 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -101,7 +101,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray: def _plot( # type: ignore[override] cls, ax: Axes, - y, + y: np.ndarray, style=None, bottom: int | np.ndarray = 0, column_num: int = 0, @@ -166,7 +166,7 @@ def _make_plot(self, fig: Figure) -> None: self._append_legend_handles_labels(artists[0], label) - def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None: + def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None: """merge BoxPlot/KdePlot properties to passed kwds""" # y is required for KdePlot kwds["bottom"] = self.bottom @@ -225,7 +225,7 @@ def __init__( self.weights = weights @staticmethod - def _get_ind(y, ind): + def _get_ind(y: np.ndarray, ind): if ind is None: # np.nanmax() and np.nanmin() ignores the missing values sample_range = np.nanmax(y) - np.nanmin(y) @@ -248,12 +248,12 @@ def _get_ind(y, ind): def _plot( # type: ignore[override] cls, ax: Axes, - y, + y: np.ndarray, style=None, bw_method=None, ind=None, column_num=None, - stacking_id=None, + stacking_id: int | None = None, **kwds, ): from scipy.stats import gaussian_kde @@ -265,11 +265,11 @@ def _plot( # type: ignore[override] lines = MPLPlot._plot(ax, ind, y, style=style, **kwds) return lines - def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None: + def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None: kwds["bw_method"] = self.bw_method kwds["ind"] = self._get_ind(y, ind=self.ind) - def _post_plot_logic(self, ax, data) -> None: + def _post_plot_logic(self, ax: Axes, data) -> None: ax.set_ylabel("Density") diff --git a/pandas/plotting/_matplotlib/timeseries.py b/pandas/plotting/_matplotlib/timeseries.py index e3e02c69a26db..714ee9d162a25 100644 --- a/pandas/plotting/_matplotlib/timeseries.py +++ b/pandas/plotting/_matplotlib/timeseries.py @@ -5,6 +5,7 @@ import functools from typing import ( TYPE_CHECKING, + Any, cast, ) import warnings @@ -44,6 +45,8 @@ from matplotlib.axes import Axes + from pandas._typing import NDFrameT + from pandas import ( DataFrame, DatetimeIndex, @@ -56,7 +59,7 @@ # Plotting functions and monkey patches -def maybe_resample(series: Series, ax: Axes, kwargs): +def maybe_resample(series: Series, ax: Axes, kwargs: dict[str, Any]): # resample against axes freq if necessary freq, ax_freq = _get_freq(ax, series) @@ -99,7 +102,7 @@ def _is_sup(f1: str, f2: str) -> bool: ) -def _upsample_others(ax: Axes, freq, kwargs) -> None: +def _upsample_others(ax: Axes, freq: BaseOffset, kwargs: dict[str, Any]) -> None: legend = ax.get_legend() lines, labels = _replot_ax(ax, freq, kwargs) _replot_ax(ax, freq, kwargs) @@ -122,7 +125,7 @@ def _upsample_others(ax: Axes, freq, kwargs) -> None: ax.legend(lines, labels, loc="best", title=title) -def _replot_ax(ax: Axes, freq, kwargs): +def _replot_ax(ax: Axes, freq: BaseOffset, kwargs: dict[str, Any]): data = getattr(ax, "_plot_data", None) # clear current axes and data @@ -152,7 +155,7 @@ def _replot_ax(ax: Axes, freq, kwargs): return lines, labels -def decorate_axes(ax: Axes, freq, kwargs) -> None: +def decorate_axes(ax: Axes, freq: BaseOffset, kwargs: dict[str, Any]) -> None: """Initialize axes for time-series plotting""" if not hasattr(ax, "_plot_data"): ax._plot_data = [] @@ -264,7 +267,7 @@ def _get_index_freq(index: Index) -> BaseOffset | None: return freq -def maybe_convert_index(ax: Axes, data): +def maybe_convert_index(ax: Axes, data: NDFrameT) -> NDFrameT: # tsplot converts automatically, but don't want to convert index # over and over for DataFrames if isinstance(data.index, (ABCDatetimeIndex, ABCPeriodIndex)):