Skip to content

REF: make plotting less stateful (4) #55872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from matplotlib.artist import setp
import numpy as np

from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import is_dict_like
Expand Down Expand Up @@ -132,15 +133,29 @@ def _validate_color_args(self):
else:
self.color = None

@cache_readonly
def _color_attrs(self):
# get standard colors for default
colors = get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
# use 2 colors by default, for box/whisker and median
# flier colors isn't needed here
# because it can be specified by ``sym`` kw
self._boxes_c = colors[0]
self._whiskers_c = colors[0]
self._medians_c = colors[2]
self._caps_c = colors[0]
return get_standard_colors(num_colors=3, colormap=self.colormap, color=None)

@cache_readonly
def _boxes_c(self):
    """
    Default color for boxes.

    Returns the first entry of ``self._color_attrs``; boxes and whiskers
    share that entry by default.
    """
    return self._color_attrs[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can de-duplicate this in a follow-up, since three properties cache the same `self._color_attrs[0]`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. There will be a few more passes in this less-stateful push.


@cache_readonly
def _whiskers_c(self):
    """
    Default color for whiskers.

    Returns the first entry of ``self._color_attrs``; boxes and whiskers
    share that entry by default.
    """
    return self._color_attrs[0]

@cache_readonly
def _medians_c(self):
    """
    Default color for median lines.

    Returns the third entry (index 2) of ``self._color_attrs``, so medians
    use a different default color from boxes and whiskers.
    """
    return self._color_attrs[2]

@cache_readonly
def _caps_c(self):
    """
    Default color for caps.

    Returns the first entry of ``self._color_attrs``, the same entry used
    for boxes and whiskers.
    """
    return self._color_attrs[0]

def _get_colors(
self,
Expand Down
20 changes: 14 additions & 6 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@
npt,
)

from pandas import Series


def _color_in_style(style: str) -> bool:
"""
Expand Down Expand Up @@ -471,7 +473,8 @@ def generate(self) -> None:
self._post_plot_logic(ax, self.data)

@final
def _has_plotted_object(self, ax: Axes) -> bool:
@staticmethod
def _has_plotted_object(ax: Axes) -> bool:
"""check whether ax has data"""
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

Expand Down Expand Up @@ -576,7 +579,8 @@ def result(self):
return self.axes[0]

@final
def _convert_to_ndarray(self, data):
@staticmethod
def _convert_to_ndarray(data):
# GH31357: categorical columns are processed separately
if isinstance(data.dtype, CategoricalDtype):
return data
Expand Down Expand Up @@ -767,6 +771,7 @@ def _apply_axis_properties(
if fontsize is not None:
label.set_fontsize(fontsize)

@final
@property
def legend_title(self) -> str | None:
if not isinstance(self.data.columns, ABCMultiIndex):
Expand Down Expand Up @@ -836,7 +841,8 @@ def _make_legend(self) -> None:
ax.legend(loc="best")

@final
def _get_ax_legend(self, ax: Axes):
@staticmethod
def _get_ax_legend(ax: Axes):
"""
Take in axes and return ax and legend under different scenarios
"""
Expand Down Expand Up @@ -1454,7 +1460,7 @@ def _plot( # type: ignore[override]
return lines

@final
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
# column_num must be in kwds for stacking purpose
Expand All @@ -1471,11 +1477,13 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):

lines = self._plot(ax, data.index, data.values, style=style, **kwds)
# set date formatter, locators and rescale limits
format_dateaxis(ax, ax.freq, data.index)
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
# expected "DatetimeIndex | PeriodIndex"
format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type]
return lines

@final
def _get_stacking_id(self):
def _get_stacking_id(self) -> int | None:
if self.stacked:
return id(self.data)
else:
Expand Down
19 changes: 11 additions & 8 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@ def __init__(
data,
bins: int | np.ndarray | list[np.ndarray] = 10,
bottom: int | np.ndarray = 0,
*,
range=None,
**kwargs,
) -> None:
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self._bin_range = range

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
# Do not call LinePlot.__init__ which may fill nan
Expand All @@ -85,7 +89,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
values = np.ravel(nd_values)
values = values[~isna(values)]

hist, bins = np.histogram(values, bins=bins, range=self.kwds.get("range", None))
hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
return bins

# error: Signature of "_plot" incompatible with supertype "LinePlot"
Expand Down Expand Up @@ -209,24 +213,23 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
self.bw_method = bw_method
self.ind = ind

def _get_ind(self, y):
if self.ind is None:
@staticmethod
def _get_ind(y, ind):
if ind is None:
# np.nanmax() and np.nanmin() ignores the missing values
sample_range = np.nanmax(y) - np.nanmin(y)
ind = np.linspace(
np.nanmin(y) - 0.5 * sample_range,
np.nanmax(y) + 0.5 * sample_range,
1000,
)
elif is_integer(self.ind):
elif is_integer(ind):
sample_range = np.nanmax(y) - np.nanmin(y)
ind = np.linspace(
np.nanmin(y) - 0.5 * sample_range,
np.nanmax(y) + 0.5 * sample_range,
self.ind,
ind,
)
else:
ind = self.ind
return ind

@classmethod
Expand All @@ -252,7 +255,7 @@ def _plot(

def _make_plot_keywords(self, kwds, y):
kwds["bw_method"] = self.bw_method
kwds["ind"] = self._get_ind(y)
kwds["ind"] = self._get_ind(y, ind=self.ind)
return kwds

def _post_plot_logic(self, ax, data) -> None:
Expand Down
8 changes: 5 additions & 3 deletions pandas/plotting/_matplotlib/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
DataFrame,
DatetimeIndex,
Index,
PeriodIndex,
Series,
)

Expand Down Expand Up @@ -300,16 +301,17 @@ def maybe_convert_index(ax: Axes, data):
return data


# Patch methods for subplot. Only format_dateaxis is currently used.
# Do we need the rest for convenience?
# Patch methods for subplot.


def _format_coord(freq, t, y) -> str:
    """
    Build the coordinate read-out string for a point on a period axis.

    Parameters
    ----------
    freq
        Frequency passed straight through to ``Period``.
    t
        x position; truncated with ``int`` and used as the period ordinal.
    y
        y position; rendered with the ``8f`` format spec (minimum width 8,
        six decimal places).

    Returns
    -------
    str
        Text of the form ``"t = <period> y = <y>"``.
    """
    # The x value is an ordinal on the period axis, so it is turned back
    # into a Period for display.
    time_period = Period(ordinal=int(t), freq=freq)
    return f"t = {time_period} y = {y:8f}"


def format_dateaxis(subplot, freq, index) -> None:
def format_dateaxis(
subplot, freq: BaseOffset, index: DatetimeIndex | PeriodIndex
) -> None:
"""
Pretty-formats the date axis (x-axis).

Expand Down