Skip to content

TYP: plotting #55829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
TYPE_CHECKING,
Literal,
final,
)
import warnings

Expand Down Expand Up @@ -272,6 +273,7 @@ def __init__(

self._validate_color_args()

@final
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
) -> bool | list[tuple[int, ...]]:
Expand Down Expand Up @@ -418,6 +420,7 @@ def _validate_color_args(self):
"other or pass 'style' without a color symbol"
)

@final
def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
if data is None:
data = self.data
Expand All @@ -443,9 +446,11 @@ def nseries(self) -> int:
else:
return self.data.shape[1]

@final
def draw(self) -> None:
self.plt.draw_if_interactive()

@final
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
Expand All @@ -463,11 +468,13 @@ def generate(self) -> None:
def _args_adjust(self) -> None:
pass

@final
def _has_plotted_object(self, ax: Axes) -> bool:
"""check whether ax has data"""
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

def _maybe_right_yaxis(self, ax: Axes, axes_num):
@final
def _maybe_right_yaxis(self, ax: Axes, axes_num: int):
if not self.on_right(axes_num):
# secondary axes may be passed via ax kw
return self._get_ax_layer(ax)
Expand Down Expand Up @@ -495,6 +502,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
new_ax.set_yscale("symlog")
return new_ax

@final
def _setup_subplots(self):
if self.subplots:
naxes = (
Expand Down Expand Up @@ -565,6 +573,7 @@ def result(self):
else:
return self.axes[0]

@final
def _convert_to_ndarray(self, data):
# GH31357: categorical columns are processed separately
if isinstance(data.dtype, CategoricalDtype):
Expand All @@ -583,6 +592,7 @@ def _convert_to_ndarray(self, data):

return data

@final
def _compute_plot_data(self):
data = self.data

Expand Down Expand Up @@ -640,6 +650,7 @@ def _compute_plot_data(self):
def _make_plot(self):
raise AbstractMethodError(self)

@final
def _add_table(self) -> None:
if self.table is False:
return
Expand All @@ -650,6 +661,7 @@ def _add_table(self) -> None:
ax = self._get_ax(0)
tools.table(ax, data)

@final
def _post_plot_logic_common(self, ax, data):
"""Common post process for each axes"""
if self.orientation == "vertical" or self.orientation is None:
Expand All @@ -672,6 +684,7 @@ def _post_plot_logic_common(self, ax, data):
def _post_plot_logic(self, ax, data) -> None:
"""Post process for each axes. Overridden in child classes"""

@final
def _adorn_subplots(self):
"""Common post process unrelated to data"""
if len(self.axes) > 0:
Expand Down Expand Up @@ -733,6 +746,7 @@ def _adorn_subplots(self):
raise ValueError(msg)
self.axes[0].set_title(self.title)

@final
def _apply_axis_properties(
self, axis: Axis, rot=None, fontsize: int | None = None
) -> None:
Expand Down Expand Up @@ -762,6 +776,7 @@ def legend_title(self) -> str | None:
stringified = map(pprint_thing, self.data.columns.names)
return ",".join(stringified)

@final
def _mark_right_label(self, label: str, index: int) -> str:
"""
Append ``(right)`` to the label of a line if it's plotted on the right axis.
Expand All @@ -772,6 +787,7 @@ def _mark_right_label(self, label: str, index: int) -> str:
label += " (right)"
return label

@final
def _append_legend_handles_labels(self, handle: Artist, label: str) -> None:
"""
Append current handle and label to ``legend_handles`` and ``legend_labels``.
Expand Down Expand Up @@ -817,6 +833,7 @@ def _make_legend(self) -> None:
if ax.get_visible():
ax.legend(loc="best")

@final
def _get_ax_legend(self, ax: Axes):
"""
Take in axes and return ax and legend under different scenarios
Expand All @@ -832,6 +849,7 @@ def _get_ax_legend(self, ax: Axes):
ax = other_ax
return ax, leg

@final
@cache_readonly
def plt(self):
import matplotlib.pyplot as plt
Expand All @@ -840,6 +858,7 @@ def plt(self):

_need_to_set_index = False

@final
def _get_xticks(self, convert_period: bool = False):
index = self.data.index
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
Expand Down Expand Up @@ -894,6 +913,7 @@ def _get_custom_index_name(self):
"""Specify whether xlabel/ylabel should be used to override index name"""
return self.xlabel

@final
def _get_index_name(self) -> str | None:
if isinstance(self.data.index, ABCMultiIndex):
name = self.data.index.names
Expand All @@ -913,6 +933,7 @@ def _get_index_name(self) -> str | None:

return name

@final
@classmethod
def _get_ax_layer(cls, ax, primary: bool = True):
"""get left (primary) or right (secondary) axes"""
Expand All @@ -921,6 +942,7 @@ def _get_ax_layer(cls, ax, primary: bool = True):
else:
return getattr(ax, "right_ax", ax)

@final
def _col_idx_to_axis_idx(self, col_idx: int) -> int:
"""Return the index of the axis where the column at col_idx should be plotted"""
if isinstance(self.subplots, list):
Expand All @@ -934,6 +956,7 @@ def _col_idx_to_axis_idx(self, col_idx: int) -> int:
# subplots is True: one ax per column
return col_idx

@final
def _get_ax(self, i: int):
# get the twinx ax if appropriate
if self.subplots:
Expand All @@ -948,6 +971,7 @@ def _get_ax(self, i: int):
ax.get_yaxis().set_visible(True)
return ax

@final
@classmethod
def get_default_ax(cls, ax) -> None:
import matplotlib.pyplot as plt
Expand All @@ -957,13 +981,15 @@ def get_default_ax(cls, ax) -> None:
ax = plt.gca()
ax = cls._get_ax_layer(ax)

def on_right(self, i):
@final
def on_right(self, i: int):
if isinstance(self.secondary_y, bool):
return self.secondary_y

if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)):
return self.data.columns[i] in self.secondary_y

@final
def _apply_style_colors(self, colors, kwds, col_num, label: str):
"""
Manage style and color based on column number and its label.
Expand Down Expand Up @@ -1004,6 +1030,7 @@ def _get_colors(
color=self.kwds.get(color_kwds),
)

@final
def _parse_errorbars(self, label, err):
"""
Look for error keyword arguments and return the actual errorbar data
Expand Down Expand Up @@ -1093,6 +1120,7 @@ def match_labels(data, e):

return err

@final
def _get_errorbars(
self, label=None, index=None, xerr: bool = True, yerr: bool = True
):
Expand All @@ -1114,6 +1142,7 @@ def _get_errorbars(
errors[kw] = err
return errors

@final
def _get_subplots(self):
from matplotlib.axes import Subplot

Expand All @@ -1123,6 +1152,7 @@ def _get_subplots(self):
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
]

@final
def _get_axes_layout(self) -> tuple[int, int]:
axes = self._get_subplots()
x_set = set()
Expand Down Expand Up @@ -1161,17 +1191,20 @@ def __init__(self, data, x, y, **kwargs) -> None:
self.x = x
self.y = y

@final
@property
def nseries(self) -> int:
return 1

@final
def _post_plot_logic(self, ax: Axes, data) -> None:
x, y = self.x, self.y
xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x)
ylabel = self.ylabel if self.ylabel is not None else pprint_thing(y)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

@final
def _plot_colorbar(self, ax: Axes, **kwds):
# Addresses issues #10611 and #10678:
# When plotting scatterplots and hexbinplots in IPython
Expand Down Expand Up @@ -1351,10 +1384,12 @@ def __init__(self, data, **kwargs) -> None:
if "x_compat" in self.kwds:
self.x_compat = bool(self.kwds.pop("x_compat"))

@final
def _is_ts_plot(self) -> bool:
# this is slightly deceptive
return not self.x_compat and self.use_index and self._use_dynamic_x()

@final
def _use_dynamic_x(self):
return use_dynamic_x(self._get_ax(0), self.data)

Expand Down Expand Up @@ -1422,6 +1457,7 @@ def _plot( # type: ignore[override]
cls._update_stacker(ax, stacking_id, y)
return lines

@final
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
Expand All @@ -1442,12 +1478,14 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
format_dateaxis(ax, ax.freq, data.index)
return lines

@final
def _get_stacking_id(self):
if self.stacked:
return id(self.data)
else:
return None

@final
@classmethod
def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
if stacking_id is None:
Expand All @@ -1459,6 +1497,7 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
ax._stacker_pos_prior[stacking_id] = np.zeros(n)
ax._stacker_neg_prior[stacking_id] = np.zeros(n)

@final
@classmethod
def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
if stacking_id is None:
Expand All @@ -1478,6 +1517,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
f"Column '{label}' contains both positive and negative values"
)

@final
@classmethod
def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
if stacking_id is None:
Expand Down