Skip to content

Commit 97c61e8

Browse files
authored
REF: make plotting less stateful (4) (#55872)
* REF: make plotting less stateful (4) * REF: make plotting less stateful (4) * TYP: plotting * REF: make plotting less stateful (4)
1 parent 21fa354 commit 97c61e8

File tree

4 files changed

+50
-22
lines changed

4 files changed

+50
-22
lines changed

pandas/plotting/_matplotlib/boxplot.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from matplotlib.artist import setp
1111
import numpy as np
1212

13+
from pandas.util._decorators import cache_readonly
1314
from pandas.util._exceptions import find_stack_level
1415

1516
from pandas.core.dtypes.common import is_dict_like
@@ -132,15 +133,29 @@ def _validate_color_args(self):
132133
else:
133134
self.color = None
134135

136+
@cache_readonly
137+
def _color_attrs(self):
135138
# get standard colors for default
136-
colors = get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
137139
# use 2 colors by default, for box/whisker and median
138140
# flier colors isn't needed here
139141
# because it can be specified by ``sym`` kw
140-
self._boxes_c = colors[0]
141-
self._whiskers_c = colors[0]
142-
self._medians_c = colors[2]
143-
self._caps_c = colors[0]
142+
return get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
143+
144+
@cache_readonly
145+
def _boxes_c(self):
146+
return self._color_attrs[0]
147+
148+
@cache_readonly
149+
def _whiskers_c(self):
150+
return self._color_attrs[0]
151+
152+
@cache_readonly
153+
def _medians_c(self):
154+
return self._color_attrs[2]
155+
156+
@cache_readonly
157+
def _caps_c(self):
158+
return self._color_attrs[0]
144159

145160
def _get_colors(
146161
self,

pandas/plotting/_matplotlib/core.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@
9090
npt,
9191
)
9292

93+
from pandas import Series
94+
9395

9496
def _color_in_style(style: str) -> bool:
9597
"""
@@ -471,7 +473,8 @@ def generate(self) -> None:
471473
self._post_plot_logic(ax, self.data)
472474

473475
@final
474-
def _has_plotted_object(self, ax: Axes) -> bool:
476+
@staticmethod
477+
def _has_plotted_object(ax: Axes) -> bool:
475478
"""check whether ax has data"""
476479
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0
477480

@@ -576,7 +579,8 @@ def result(self):
576579
return self.axes[0]
577580

578581
@final
579-
def _convert_to_ndarray(self, data):
582+
@staticmethod
583+
def _convert_to_ndarray(data):
580584
# GH31357: categorical columns are processed separately
581585
if isinstance(data.dtype, CategoricalDtype):
582586
return data
@@ -767,6 +771,7 @@ def _apply_axis_properties(
767771
if fontsize is not None:
768772
label.set_fontsize(fontsize)
769773

774+
@final
770775
@property
771776
def legend_title(self) -> str | None:
772777
if not isinstance(self.data.columns, ABCMultiIndex):
@@ -836,7 +841,8 @@ def _make_legend(self) -> None:
836841
ax.legend(loc="best")
837842

838843
@final
839-
def _get_ax_legend(self, ax: Axes):
844+
@staticmethod
845+
def _get_ax_legend(ax: Axes):
840846
"""
841847
Take in axes and return ax and legend under different scenarios
842848
"""
@@ -1454,7 +1460,7 @@ def _plot( # type: ignore[override]
14541460
return lines
14551461

14561462
@final
1457-
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
1463+
def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
14581464
# accept x to be consistent with normal plot func,
14591465
# x is not passed to tsplot as it uses data.index as x coordinate
14601466
# column_num must be in kwds for stacking purpose
@@ -1471,11 +1477,13 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
14711477

14721478
lines = self._plot(ax, data.index, data.values, style=style, **kwds)
14731479
# set date formatter, locators and rescale limits
1474-
format_dateaxis(ax, ax.freq, data.index)
1480+
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
1481+
# expected "DatetimeIndex | PeriodIndex"
1482+
format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type]
14751483
return lines
14761484

14771485
@final
1478-
def _get_stacking_id(self):
1486+
def _get_stacking_id(self) -> int | None:
14791487
if self.stacked:
14801488
return id(self.data)
14811489
else:

pandas/plotting/_matplotlib/hist.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,16 @@ def __init__(
5656
data,
5757
bins: int | np.ndarray | list[np.ndarray] = 10,
5858
bottom: int | np.ndarray = 0,
59+
*,
60+
range=None,
5961
**kwargs,
6062
) -> None:
6163
if is_list_like(bottom):
6264
bottom = np.array(bottom)
6365
self.bottom = bottom
6466

67+
self._bin_range = range
68+
6569
self.xlabel = kwargs.get("xlabel")
6670
self.ylabel = kwargs.get("ylabel")
6771
# Do not call LinePlot.__init__ which may fill nan
@@ -85,7 +89,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
8589
values = np.ravel(nd_values)
8690
values = values[~isna(values)]
8791

88-
hist, bins = np.histogram(values, bins=bins, range=self.kwds.get("range", None))
92+
hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
8993
return bins
9094

9195
# error: Signature of "_plot" incompatible with supertype "LinePlot"
@@ -209,24 +213,23 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
209213
self.bw_method = bw_method
210214
self.ind = ind
211215

212-
def _get_ind(self, y):
213-
if self.ind is None:
216+
@staticmethod
217+
def _get_ind(y, ind):
218+
if ind is None:
214219
# np.nanmax() and np.nanmin() ignores the missing values
215220
sample_range = np.nanmax(y) - np.nanmin(y)
216221
ind = np.linspace(
217222
np.nanmin(y) - 0.5 * sample_range,
218223
np.nanmax(y) + 0.5 * sample_range,
219224
1000,
220225
)
221-
elif is_integer(self.ind):
226+
elif is_integer(ind):
222227
sample_range = np.nanmax(y) - np.nanmin(y)
223228
ind = np.linspace(
224229
np.nanmin(y) - 0.5 * sample_range,
225230
np.nanmax(y) + 0.5 * sample_range,
226-
self.ind,
231+
ind,
227232
)
228-
else:
229-
ind = self.ind
230233
return ind
231234

232235
@classmethod
@@ -252,7 +255,7 @@ def _plot(
252255

253256
def _make_plot_keywords(self, kwds, y):
254257
kwds["bw_method"] = self.bw_method
255-
kwds["ind"] = self._get_ind(y)
258+
kwds["ind"] = self._get_ind(y, ind=self.ind)
256259
return kwds
257260

258261
def _post_plot_logic(self, ax, data) -> None:

pandas/plotting/_matplotlib/timeseries.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
DataFrame,
4949
DatetimeIndex,
5050
Index,
51+
PeriodIndex,
5152
Series,
5253
)
5354

@@ -300,16 +301,17 @@ def maybe_convert_index(ax: Axes, data):
300301
return data
301302

302303

303-
# Patch methods for subplot. Only format_dateaxis is currently used.
304-
# Do we need the rest for convenience?
304+
# Patch methods for subplot.
305305

306306

307307
def _format_coord(freq, t, y) -> str:
308308
time_period = Period(ordinal=int(t), freq=freq)
309309
return f"t = {time_period} y = {y:8f}"
310310

311311

312-
def format_dateaxis(subplot, freq, index) -> None:
312+
def format_dateaxis(
313+
subplot, freq: BaseOffset, index: DatetimeIndex | PeriodIndex
314+
) -> None:
313315
"""
314316
Pretty-formats the date axis (x-axis).
315317

0 commit comments

Comments
 (0)