89
89
90
90
from pandas ._typing import (
91
91
IndexLabel ,
92
+ NDFrameT ,
92
93
PlottingOrientation ,
93
94
npt ,
94
95
)
95
96
96
- from pandas import Series
97
+ from pandas import (
98
+ PeriodIndex ,
99
+ Series ,
100
+ )
97
101
98
102
99
103
def _color_in_style (style : str ) -> bool :
@@ -161,8 +165,6 @@ def __init__(
161
165
) -> None :
162
166
import matplotlib .pyplot as plt
163
167
164
- self .data = data
165
-
166
168
# if users assign an empty list or tuple, raise `ValueError`
167
169
# similar to current `df.box` and `df.hist` APIs.
168
170
if by in ([], ()):
@@ -193,9 +195,11 @@ def __init__(
193
195
194
196
self .kind = kind
195
197
196
- self .subplots = self ._validate_subplots_kwarg (subplots )
198
+ self .subplots = type (self )._validate_subplots_kwarg (
199
+ subplots , data , kind = self ._kind
200
+ )
197
201
198
- self .sharex = self ._validate_sharex (sharex , ax , by )
202
+ self .sharex = type ( self ) ._validate_sharex (sharex , ax , by )
199
203
self .sharey = sharey
200
204
self .figsize = figsize
201
205
self .layout = layout
@@ -245,10 +249,11 @@ def __init__(
245
249
# parse errorbar input if given
246
250
xerr = kwds .pop ("xerr" , None )
247
251
yerr = kwds .pop ("yerr" , None )
248
- self .errors = {
249
- kw : self ._parse_errorbars (kw , err )
250
- for kw , err in zip (["xerr" , "yerr" ], [xerr , yerr ])
251
- }
252
+ nseries = self ._get_nseries (data )
253
+ xerr , data = type (self )._parse_errorbars ("xerr" , xerr , data , nseries )
254
+ yerr , data = type (self )._parse_errorbars ("yerr" , yerr , data , nseries )
255
+ self .errors = {"xerr" : xerr , "yerr" : yerr }
256
+ self .data = data
252
257
253
258
if not isinstance (secondary_y , (bool , tuple , list , np .ndarray , ABCIndex )):
254
259
secondary_y = [secondary_y ]
@@ -271,7 +276,8 @@ def __init__(
271
276
self ._validate_color_args ()
272
277
273
278
@final
274
- def _validate_sharex (self , sharex : bool | None , ax , by ) -> bool :
279
+ @staticmethod
280
+ def _validate_sharex (sharex : bool | None , ax , by ) -> bool :
275
281
if sharex is None :
276
282
# if by is defined, subplots are used and sharex should be False
277
283
if ax is None and by is None : # pylint: disable=simplifiable-if-statement
@@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
285
291
return bool (sharex )
286
292
287
293
@final
294
+ @staticmethod
288
295
def _validate_subplots_kwarg (
289
- self , subplots : bool | Sequence [Sequence [str ]]
296
+ subplots : bool | Sequence [Sequence [str ]], data : Series | DataFrame , kind : str
290
297
) -> bool | list [tuple [int , ...]]:
291
298
"""
292
299
Validate the subplots parameter
@@ -323,18 +330,18 @@ def _validate_subplots_kwarg(
323
330
"area" ,
324
331
"pie" ,
325
332
)
326
- if self . _kind not in supported_kinds :
333
+ if kind not in supported_kinds :
327
334
raise ValueError (
328
335
"When subplots is an iterable, kind must be "
329
- f"one of { ', ' .join (supported_kinds )} . Got { self . _kind } ."
336
+ f"one of { ', ' .join (supported_kinds )} . Got { kind } ."
330
337
)
331
338
332
- if isinstance (self . data , ABCSeries ):
339
+ if isinstance (data , ABCSeries ):
333
340
raise NotImplementedError (
334
341
"An iterable subplots for a Series is not supported."
335
342
)
336
343
337
- columns = self . data .columns
344
+ columns = data .columns
338
345
if isinstance (columns , ABCMultiIndex ):
339
346
raise NotImplementedError (
340
347
"An iterable subplots for a DataFrame with a MultiIndex column "
@@ -442,18 +449,22 @@ def _iter_data(
442
449
# typing.
443
450
yield col , np .asarray (values .values )
444
451
445
- @property
446
- def nseries (self ) -> int :
452
+ def _get_nseries (self , data : Series | DataFrame ) -> int :
447
453
# When `by` is explicitly assigned, grouped data size will be defined, and
448
454
# this will determine number of subplots to have, aka `self.nseries`
449
- if self . data .ndim == 1 :
455
+ if data .ndim == 1 :
450
456
return 1
451
457
elif self .by is not None and self ._kind == "hist" :
452
458
return len (self ._grouped )
453
459
elif self .by is not None and self ._kind == "box" :
454
460
return len (self .columns )
455
461
else :
456
- return self .data .shape [1 ]
462
+ return data .shape [1 ]
463
+
464
+ @final
465
+ @property
466
+ def nseries (self ) -> int :
467
+ return self ._get_nseries (self .data )
457
468
458
469
@final
459
470
def draw (self ) -> None :
@@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False):
880
891
index = self .data .index
881
892
is_datetype = index .inferred_type in ("datetime" , "date" , "datetime64" , "time" )
882
893
894
+ x : list [int ] | np .ndarray
883
895
if self .use_index :
884
896
if convert_period and isinstance (index , ABCPeriodIndex ):
885
897
self .data = self .data .reindex (index = index .sort_values ())
886
- x = self .data .index .to_timestamp ()._mpl_repr ()
898
+ index = cast ("PeriodIndex" , self .data .index )
899
+ x = index .to_timestamp ()._mpl_repr ()
887
900
elif is_any_real_numeric_dtype (index .dtype ):
888
901
# Matplotlib supports numeric values or datetime objects as
889
902
# xaxis values. Taking LBYL approach here, by the time
@@ -1050,8 +1063,12 @@ def _get_colors(
1050
1063
color = self .kwds .get (color_kwds ),
1051
1064
)
1052
1065
1066
+ # TODO: tighter typing for first return?
1053
1067
@final
1054
- def _parse_errorbars (self , label : str , err ):
1068
+ @staticmethod
1069
+ def _parse_errorbars (
1070
+ label : str , err , data : NDFrameT , nseries : int
1071
+ ) -> tuple [Any , NDFrameT ]:
1055
1072
"""
1056
1073
Look for error keyword arguments and return the actual errorbar data
1057
1074
or return the error DataFrame/dict
@@ -1071,32 +1088,32 @@ def _parse_errorbars(self, label: str, err):
1071
1088
should be in a ``Mx2xN`` array.
1072
1089
"""
1073
1090
if err is None :
1074
- return None
1091
+ return None , data
1075
1092
1076
1093
def match_labels (data , e ):
1077
1094
e = e .reindex (data .index )
1078
1095
return e
1079
1096
1080
1097
# key-matched DataFrame
1081
1098
if isinstance (err , ABCDataFrame ):
1082
- err = match_labels (self . data , err )
1099
+ err = match_labels (data , err )
1083
1100
# key-matched dict
1084
1101
elif isinstance (err , dict ):
1085
1102
pass
1086
1103
1087
1104
# Series of error values
1088
1105
elif isinstance (err , ABCSeries ):
1089
1106
# broadcast error series across data
1090
- err = match_labels (self . data , err )
1107
+ err = match_labels (data , err )
1091
1108
err = np .atleast_2d (err )
1092
- err = np .tile (err , (self . nseries , 1 ))
1109
+ err = np .tile (err , (nseries , 1 ))
1093
1110
1094
1111
# errors are a column in the dataframe
1095
1112
elif isinstance (err , str ):
1096
- evalues = self . data [err ].values
1097
- self . data = self . data [self . data .columns .drop (err )]
1113
+ evalues = data [err ].values
1114
+ data = data [data .columns .drop (err )]
1098
1115
err = np .atleast_2d (evalues )
1099
- err = np .tile (err , (self . nseries , 1 ))
1116
+ err = np .tile (err , (nseries , 1 ))
1100
1117
1101
1118
elif is_list_like (err ):
1102
1119
if is_iterator (err ):
@@ -1108,40 +1125,40 @@ def match_labels(data, e):
1108
1125
err_shape = err .shape
1109
1126
1110
1127
# asymmetrical error bars
1111
- if isinstance (self . data , ABCSeries ) and err_shape [0 ] == 2 :
1128
+ if isinstance (data , ABCSeries ) and err_shape [0 ] == 2 :
1112
1129
err = np .expand_dims (err , 0 )
1113
1130
err_shape = err .shape
1114
- if err_shape [2 ] != len (self . data ):
1131
+ if err_shape [2 ] != len (data ):
1115
1132
raise ValueError (
1116
1133
"Asymmetrical error bars should be provided "
1117
- f"with the shape (2, { len (self . data )} )"
1134
+ f"with the shape (2, { len (data )} )"
1118
1135
)
1119
- elif isinstance (self . data , ABCDataFrame ) and err .ndim == 3 :
1136
+ elif isinstance (data , ABCDataFrame ) and err .ndim == 3 :
1120
1137
if (
1121
- (err_shape [0 ] != self . nseries )
1138
+ (err_shape [0 ] != nseries )
1122
1139
or (err_shape [1 ] != 2 )
1123
- or (err_shape [2 ] != len (self . data ))
1140
+ or (err_shape [2 ] != len (data ))
1124
1141
):
1125
1142
raise ValueError (
1126
1143
"Asymmetrical error bars should be provided "
1127
- f"with the shape ({ self . nseries } , 2, { len (self . data )} )"
1144
+ f"with the shape ({ nseries } , 2, { len (data )} )"
1128
1145
)
1129
1146
1130
1147
# broadcast errors to each data series
1131
1148
if len (err ) == 1 :
1132
- err = np .tile (err , (self . nseries , 1 ))
1149
+ err = np .tile (err , (nseries , 1 ))
1133
1150
1134
1151
elif is_number (err ):
1135
1152
err = np .tile (
1136
1153
[err ], # pyright: ignore[reportGeneralTypeIssues]
1137
- (self . nseries , len (self . data )),
1154
+ (nseries , len (data )),
1138
1155
)
1139
1156
1140
1157
else :
1141
1158
msg = f"No valid { label } detected"
1142
1159
raise ValueError (msg )
1143
1160
1144
- return err
1161
+ return err , data # pyright: ignore[reportGeneralTypeIssues]
1145
1162
1146
1163
@final
1147
1164
def _get_errorbars (
@@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None:
1215
1232
self .y = y
1216
1233
1217
1234
@final
1218
- @property
1219
- def nseries (self ) -> int :
1235
+ def _get_nseries (self , data : Series | DataFrame ) -> int :
1220
1236
return 1
1221
1237
1222
1238
@final
0 commit comments