Skip to content

Commit a41b545

Browse files
authored
REF: Mess with data less outside of MPLPlot.__init__ (#55889)
* REF: Mess with data less outside of MPLPlot.__init__
* lint fixup
* pyright ignore
* pyright ignore
1 parent 5cedf87 commit a41b545

File tree

2 files changed

+61
-42
lines changed

2 files changed

+61
-42
lines changed

pandas/plotting/_matplotlib/core.py

+56 -40
Original file line number | Diff line number | Diff line change
@@ -89,11 +89,15 @@
8989

9090
from pandas._typing import (
9191
IndexLabel,
92+
NDFrameT,
9293
PlottingOrientation,
9394
npt,
9495
)
9596

96-
from pandas import Series
97+
from pandas import (
98+
PeriodIndex,
99+
Series,
100+
)
97101

98102

99103
def _color_in_style(style: str) -> bool:
@@ -161,8 +165,6 @@ def __init__(
161165
) -> None:
162166
import matplotlib.pyplot as plt
163167

164-
self.data = data
165-
166168
# if users assign an empty list or tuple, raise `ValueError`
167169
# similar to current `df.box` and `df.hist` APIs.
168170
if by in ([], ()):
@@ -193,9 +195,11 @@ def __init__(
193195

194196
self.kind = kind
195197

196-
self.subplots = self._validate_subplots_kwarg(subplots)
198+
self.subplots = type(self)._validate_subplots_kwarg(
199+
subplots, data, kind=self._kind
200+
)
197201

198-
self.sharex = self._validate_sharex(sharex, ax, by)
202+
self.sharex = type(self)._validate_sharex(sharex, ax, by)
199203
self.sharey = sharey
200204
self.figsize = figsize
201205
self.layout = layout
@@ -245,10 +249,11 @@ def __init__(
245249
# parse errorbar input if given
246250
xerr = kwds.pop("xerr", None)
247251
yerr = kwds.pop("yerr", None)
248-
self.errors = {
249-
kw: self._parse_errorbars(kw, err)
250-
for kw, err in zip(["xerr", "yerr"], [xerr, yerr])
251-
}
252+
nseries = self._get_nseries(data)
253+
xerr, data = type(self)._parse_errorbars("xerr", xerr, data, nseries)
254+
yerr, data = type(self)._parse_errorbars("yerr", yerr, data, nseries)
255+
self.errors = {"xerr": xerr, "yerr": yerr}
256+
self.data = data
252257

253258
if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
254259
secondary_y = [secondary_y]
@@ -271,7 +276,8 @@ def __init__(
271276
self._validate_color_args()
272277

273278
@final
274-
def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
279+
@staticmethod
280+
def _validate_sharex(sharex: bool | None, ax, by) -> bool:
275281
if sharex is None:
276282
# if by is defined, subplots are used and sharex should be False
277283
if ax is None and by is None: # pylint: disable=simplifiable-if-statement
@@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
285291
return bool(sharex)
286292

287293
@final
294+
@staticmethod
288295
def _validate_subplots_kwarg(
289-
self, subplots: bool | Sequence[Sequence[str]]
296+
subplots: bool | Sequence[Sequence[str]], data: Series | DataFrame, kind: str
290297
) -> bool | list[tuple[int, ...]]:
291298
"""
292299
Validate the subplots parameter
@@ -323,18 +330,18 @@ def _validate_subplots_kwarg(
323330
"area",
324331
"pie",
325332
)
326-
if self._kind not in supported_kinds:
333+
if kind not in supported_kinds:
327334
raise ValueError(
328335
"When subplots is an iterable, kind must be "
329-
f"one of {', '.join(supported_kinds)}. Got {self._kind}."
336+
f"one of {', '.join(supported_kinds)}. Got {kind}."
330337
)
331338

332-
if isinstance(self.data, ABCSeries):
339+
if isinstance(data, ABCSeries):
333340
raise NotImplementedError(
334341
"An iterable subplots for a Series is not supported."
335342
)
336343

337-
columns = self.data.columns
344+
columns = data.columns
338345
if isinstance(columns, ABCMultiIndex):
339346
raise NotImplementedError(
340347
"An iterable subplots for a DataFrame with a MultiIndex column "
@@ -442,18 +449,22 @@ def _iter_data(
442449
# typing.
443450
yield col, np.asarray(values.values)
444451

445-
@property
446-
def nseries(self) -> int:
452+
def _get_nseries(self, data: Series | DataFrame) -> int:
447453
# When `by` is explicitly assigned, grouped data size will be defined, and
448454
# this will determine number of subplots to have, aka `self.nseries`
449-
if self.data.ndim == 1:
455+
if data.ndim == 1:
450456
return 1
451457
elif self.by is not None and self._kind == "hist":
452458
return len(self._grouped)
453459
elif self.by is not None and self._kind == "box":
454460
return len(self.columns)
455461
else:
456-
return self.data.shape[1]
462+
return data.shape[1]
463+
464+
@final
465+
@property
466+
def nseries(self) -> int:
467+
return self._get_nseries(self.data)
457468

458469
@final
459470
def draw(self) -> None:
@@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False):
880891
index = self.data.index
881892
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
882893

894+
x: list[int] | np.ndarray
883895
if self.use_index:
884896
if convert_period and isinstance(index, ABCPeriodIndex):
885897
self.data = self.data.reindex(index=index.sort_values())
886-
x = self.data.index.to_timestamp()._mpl_repr()
898+
index = cast("PeriodIndex", self.data.index)
899+
x = index.to_timestamp()._mpl_repr()
887900
elif is_any_real_numeric_dtype(index.dtype):
888901
# Matplotlib supports numeric values or datetime objects as
889902
# xaxis values. Taking LBYL approach here, by the time
@@ -1050,8 +1063,12 @@ def _get_colors(
10501063
color=self.kwds.get(color_kwds),
10511064
)
10521065

1066+
# TODO: tighter typing for first return?
10531067
@final
1054-
def _parse_errorbars(self, label: str, err):
1068+
@staticmethod
1069+
def _parse_errorbars(
1070+
label: str, err, data: NDFrameT, nseries: int
1071+
) -> tuple[Any, NDFrameT]:
10551072
"""
10561073
Look for error keyword arguments and return the actual errorbar data
10571074
or return the error DataFrame/dict
@@ -1071,32 +1088,32 @@ def _parse_errorbars(self, label: str, err):
10711088
should be in a ``Mx2xN`` array.
10721089
"""
10731090
if err is None:
1074-
return None
1091+
return None, data
10751092

10761093
def match_labels(data, e):
10771094
e = e.reindex(data.index)
10781095
return e
10791096

10801097
# key-matched DataFrame
10811098
if isinstance(err, ABCDataFrame):
1082-
err = match_labels(self.data, err)
1099+
err = match_labels(data, err)
10831100
# key-matched dict
10841101
elif isinstance(err, dict):
10851102
pass
10861103

10871104
# Series of error values
10881105
elif isinstance(err, ABCSeries):
10891106
# broadcast error series across data
1090-
err = match_labels(self.data, err)
1107+
err = match_labels(data, err)
10911108
err = np.atleast_2d(err)
1092-
err = np.tile(err, (self.nseries, 1))
1109+
err = np.tile(err, (nseries, 1))
10931110

10941111
# errors are a column in the dataframe
10951112
elif isinstance(err, str):
1096-
evalues = self.data[err].values
1097-
self.data = self.data[self.data.columns.drop(err)]
1113+
evalues = data[err].values
1114+
data = data[data.columns.drop(err)]
10981115
err = np.atleast_2d(evalues)
1099-
err = np.tile(err, (self.nseries, 1))
1116+
err = np.tile(err, (nseries, 1))
11001117

11011118
elif is_list_like(err):
11021119
if is_iterator(err):
@@ -1108,40 +1125,40 @@ def match_labels(data, e):
11081125
err_shape = err.shape
11091126

11101127
# asymmetrical error bars
1111-
if isinstance(self.data, ABCSeries) and err_shape[0] == 2:
1128+
if isinstance(data, ABCSeries) and err_shape[0] == 2:
11121129
err = np.expand_dims(err, 0)
11131130
err_shape = err.shape
1114-
if err_shape[2] != len(self.data):
1131+
if err_shape[2] != len(data):
11151132
raise ValueError(
11161133
"Asymmetrical error bars should be provided "
1117-
f"with the shape (2, {len(self.data)})"
1134+
f"with the shape (2, {len(data)})"
11181135
)
1119-
elif isinstance(self.data, ABCDataFrame) and err.ndim == 3:
1136+
elif isinstance(data, ABCDataFrame) and err.ndim == 3:
11201137
if (
1121-
(err_shape[0] != self.nseries)
1138+
(err_shape[0] != nseries)
11221139
or (err_shape[1] != 2)
1123-
or (err_shape[2] != len(self.data))
1140+
or (err_shape[2] != len(data))
11241141
):
11251142
raise ValueError(
11261143
"Asymmetrical error bars should be provided "
1127-
f"with the shape ({self.nseries}, 2, {len(self.data)})"
1144+
f"with the shape ({nseries}, 2, {len(data)})"
11281145
)
11291146

11301147
# broadcast errors to each data series
11311148
if len(err) == 1:
1132-
err = np.tile(err, (self.nseries, 1))
1149+
err = np.tile(err, (nseries, 1))
11331150

11341151
elif is_number(err):
11351152
err = np.tile(
11361153
[err], # pyright: ignore[reportGeneralTypeIssues]
1137-
(self.nseries, len(self.data)),
1154+
(nseries, len(data)),
11381155
)
11391156

11401157
else:
11411158
msg = f"No valid {label} detected"
11421159
raise ValueError(msg)
11431160

1144-
return err
1161+
return err, data # pyright: ignore[reportGeneralTypeIssues]
11451162

11461163
@final
11471164
def _get_errorbars(
@@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None:
12151232
self.y = y
12161233

12171234
@final
1218-
@property
1219-
def nseries(self) -> int:
1235+
def _get_nseries(self, data: Series | DataFrame) -> int:
12201236
return 1
12211237

12221238
@final

pandas/plotting/_matplotlib/hist.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,10 @@
4545

4646
from pandas._typing import PlottingOrientation
4747

48-
from pandas import DataFrame
48+
from pandas import (
49+
DataFrame,
50+
Series,
51+
)
4952

5053

5154
class HistPlot(LinePlot):
@@ -87,7 +90,7 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
8790
bins = self._calculate_bins(self.data, bins)
8891
return bins
8992

90-
def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
93+
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
9194
"""Calculate bins given data"""
9295
nd_values = data.infer_objects(copy=False)._get_numeric_data()
9396
values = np.ravel(nd_values)

0 commit comments

Comments
 (0)