Skip to content

TYP: plotting #55887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@

from pandas._typing import IndexLabel

from pandas import DataFrame
from pandas import (
DataFrame,
Series,
)
from pandas.core.groupby.generic import DataFrameGroupBy


def hist_series(
self,
self: Series,
by=None,
ax=None,
grid: bool = True,
Expand Down Expand Up @@ -512,7 +516,7 @@ def boxplot(
@Substitution(data="", backend=_backend_doc)
@Appender(_boxplot_doc)
def boxplot_frame(
self,
self: DataFrame,
column=None,
by=None,
ax=None,
Expand Down Expand Up @@ -542,7 +546,7 @@ def boxplot_frame(


def boxplot_frame_groupby(
grouped,
grouped: DataFrameGroupBy,
subplots: bool = True,
column=None,
fontsize: int | None = None,
Expand Down Expand Up @@ -843,11 +847,11 @@ class PlotAccessor(PandasObject):
_kind_aliases = {"density": "kde"}
_all_kinds = _common_kinds + _series_kinds + _dataframe_kinds

def __init__(self, data) -> None:
def __init__(self, data: Series | DataFrame) -> None:
self._parent = data

@staticmethod
def _get_call_args(backend_name: str, data, args, kwargs):
def _get_call_args(backend_name: str, data: Series | DataFrame, args, kwargs):
"""
This function makes calls to this accessor `__call__` method compatible
with the previous `SeriesPlotMethods.__call__` and
Expand Down
16 changes: 8 additions & 8 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,18 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds
):
ys: np.ndarray | list[np.ndarray]
if y.ndim == 2:
y = [remove_na_arraylike(v) for v in y]
ys = [remove_na_arraylike(v) for v in y]
# Boxplot fails with empty arrays, so need to add a NaN
# if any cols are empty
# GH 8181
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
ys = [v if v.size > 0 else np.array([np.nan]) for v in ys]
else:
y = remove_na_arraylike(y)
bp = ax.boxplot(y, **kwds)
ys = remove_na_arraylike(y)
bp = ax.boxplot(ys, **kwds)

if return_type == "dict":
return bp, bp
Expand Down Expand Up @@ -240,8 +241,7 @@ def _make_plot(self, fig: Figure) -> None:
self.maybe_color_bp(bp)
self._return_obj = ret

labels = [left for left, _ in self._iter_data()]
labels = [pprint_thing(left) for left in labels]
labels = [pprint_thing(left) for left in self.data.columns]
if not self.use_index:
labels = [pprint_thing(key) for key in range(len(labels))]
_set_ticklabels(
Expand All @@ -251,7 +251,7 @@ def _make_plot(self, fig: Figure) -> None:
def _make_legend(self) -> None:
# No-op: the body is only `pass`, so calling this has no effect and returns None.
pass

def _post_plot_logic(self, ax, data) -> None:
def _post_plot_logic(self, ax: Axes, data) -> None:
# GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel
if self.xlabel:
ax.set_xlabel(pprint_thing(self.xlabel))
Expand Down
97 changes: 62 additions & 35 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import (
Hashable,
Iterable,
Iterator,
Sequence,
)
from typing import (
Expand Down Expand Up @@ -431,17 +432,15 @@ def _validate_color_args(self):
)

@final
def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
if data is None:
data = self.data
if fillna is not None:
data = data.fillna(fillna)

@staticmethod
def _iter_data(
data: DataFrame | dict[Hashable, Series | DataFrame]
) -> Iterator[tuple[Hashable, np.ndarray]]:
for col, values in data.items():
if keep_index is True:
yield col, values
else:
yield col, values.values
# This was originally written to use values.values before
# ExtensionArrays (EAs) were implemented; np.asarray(...) is added so
# that the yielded values are always typed as np.ndarray.
yield col, np.asarray(values.values)

@property
def nseries(self) -> int:
Expand Down Expand Up @@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool:
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

@final
def _maybe_right_yaxis(self, ax: Axes, axes_num: int):
def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes:
if not self.on_right(axes_num):
# secondary axes may be passed via ax kw
return self._get_ax_layer(ax)
Expand Down Expand Up @@ -643,11 +642,7 @@ def _compute_plot_data(self):

numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type)

try:
is_empty = numeric_data.columns.empty
except AttributeError:
is_empty = not len(numeric_data)

is_empty = numeric_data.shape[-1] == 0
# no non-numeric frames or series allowed
if is_empty:
raise TypeError("no numeric data to plot")
Expand All @@ -669,7 +664,7 @@ def _add_table(self) -> None:
tools.table(ax, data)

@final
def _post_plot_logic_common(self, ax, data):
def _post_plot_logic_common(self, ax: Axes, data) -> None:
"""Common post process for each axes"""
if self.orientation == "vertical" or self.orientation is None:
self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize)
Expand All @@ -688,7 +683,7 @@ def _post_plot_logic_common(self, ax, data):
raise ValueError

@abstractmethod
def _post_plot_logic(self, ax, data) -> None:
def _post_plot_logic(self, ax: Axes, data) -> None:
"""Post process for each axes. Overridden in child classes"""

@final
Expand Down Expand Up @@ -1042,7 +1037,7 @@ def _get_colors(
)

@final
def _parse_errorbars(self, label, err):
def _parse_errorbars(self, label: str, err):
"""
Look for error keyword arguments and return the actual errorbar data
or return the error DataFrame/dict
Expand Down Expand Up @@ -1123,7 +1118,10 @@ def match_labels(data, e):
err = np.tile(err, (self.nseries, 1))

elif is_number(err):
err = np.tile([err], (self.nseries, len(self.data)))
err = np.tile(
[err], # pyright: ignore[reportGeneralTypeIssues]
(self.nseries, len(self.data)),
)

else:
msg = f"No valid {label} detected"
Expand Down Expand Up @@ -1404,14 +1402,14 @@ def _make_plot(self, fig: Figure) -> None:

x = data.index # dummy, not used
plotf = self._ts_plot
it = self._iter_data(data=data, keep_index=True)
it = data.items()
else:
x = self._get_xticks(convert_period=True)
# error: Incompatible types in assignment (expression has type
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
plotf = self._plot # type: ignore[assignment]
it = self._iter_data()
it = self._iter_data(data=self.data)

stacking_id = self._get_stacking_id()
is_errorbar = com.any_not_none(*self.errors.values())
Expand All @@ -1420,7 +1418,12 @@ def _make_plot(self, fig: Figure) -> None:
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
kwds = self.kwds.copy()
style, kwds = self._apply_style_colors(colors, kwds, i, label)
style, kwds = self._apply_style_colors(
colors,
kwds,
i,
label, # pyright: ignore[reportGeneralTypeIssues]
)

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)
Expand All @@ -1432,7 +1435,7 @@ def _make_plot(self, fig: Figure) -> None:
newlines = plotf(
ax,
x,
y,
y, # pyright: ignore[reportGeneralTypeIssues]
style=style,
column_num=i,
stacking_id=stacking_id,
Expand All @@ -1451,7 +1454,14 @@ def _make_plot(self, fig: Figure) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds
cls,
ax: Axes,
x,
y: np.ndarray,
style=None,
column_num=None,
stacking_id=None,
**kwds,
):
# column_num is used to get the target column from plotf in line and
# area plots
Expand All @@ -1478,7 +1488,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
decorate_axes(ax.right_ax, freq, kwds)
ax._plot_data.append((data, self._kind, kwds))

lines = self._plot(ax, data.index, data.values, style=style, **kwds)
lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds)
# set date formatter, locators and rescale limits
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
# expected "DatetimeIndex | PeriodIndex"
Expand Down Expand Up @@ -1506,7 +1516,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:

@final
@classmethod
def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
def _get_stacked_values(
cls, ax: Axes, stacking_id: int | None, values: np.ndarray, label
) -> np.ndarray:
if stacking_id is None:
return values
if not hasattr(ax, "_stacker_pos_prior"):
Expand All @@ -1526,7 +1538,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):

@final
@classmethod
def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
if stacking_id is None:
return
if (values >= 0).all():
Expand Down Expand Up @@ -1604,7 +1616,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
style=None,
column_num=None,
stacking_id=None,
Expand Down Expand Up @@ -1730,7 +1742,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
w,
start: int | npt.NDArray[np.intp] = 0,
log: bool = False,
Expand All @@ -1749,7 +1761,8 @@ def _make_plot(self, fig: Figure) -> None:
pos_prior = neg_prior = np.zeros(len(self.data))
K = self.nseries

for i, (label, y) in enumerate(self._iter_data(fillna=0)):
data = self.data.fillna(0)
for i, (label, y) in enumerate(self._iter_data(data=data)):
ax = self._get_ax(i)
kwds = self.kwds.copy()
if self._is_series:
Expand Down Expand Up @@ -1828,7 +1841,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None:

self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge)

def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
def _decorate_ticks(
self,
ax: Axes,
name: str | None,
ticklabels: list[str],
start_edge: float,
end_edge: float,
) -> None:
ax.set_xlim((start_edge, end_edge))

if self.xticks is not None:
Expand Down Expand Up @@ -1862,7 +1882,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
w,
start: int | npt.NDArray[np.intp] = 0,
log: bool = False,
Expand All @@ -1873,7 +1893,14 @@ def _plot( # type: ignore[override]
def _get_custom_index_name(self):
# Returns this instance's `ylabel` attribute unchanged.
return self.ylabel

def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
def _decorate_ticks(
self,
ax: Axes,
name: str | None,
ticklabels: list[str],
start_edge: float,
end_edge: float,
) -> None:
# horizontal bars
ax.set_ylim((start_edge, end_edge))
ax.set_yticks(self.tick_pos)
Expand Down Expand Up @@ -1907,7 +1934,7 @@ def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
self.kwds.setdefault("colors", colors)

for i, (label, y) in enumerate(self._iter_data()):
for i, (label, y) in enumerate(self._iter_data(data=self.data)):
ax = self._get_ax(i)
if label is not None:
label = pprint_thing(label)
Expand Down
8 changes: 4 additions & 4 deletions pandas/plotting/_matplotlib/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from pandas.plotting._matplotlib.misc import unpack_single_str_list

if TYPE_CHECKING:
from collections.abc import Hashable

from pandas._typing import IndexLabel


def create_iter_data_given_by(
data: DataFrame, kind: str = "hist"
) -> dict[str, DataFrame | Series]:
) -> dict[Hashable, DataFrame | Series]:
"""
Create data for iteration depending on whether `by` is assigned; this is
only used by hist and boxplot.
Expand Down Expand Up @@ -126,9 +128,7 @@ def reconstruct_data_with_by(
return data


def reformat_hist_y_given_by(
y: Series | np.ndarray, by: IndexLabel | None
) -> Series | np.ndarray:
def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
"""Internal function to reformat y for a hist plot, whether or not `by` is applied.

If by is None, input y is 1-d with NaN removed; and if by is not None, groupby
Expand Down
Loading