1
1
import re
2
- from typing import TYPE_CHECKING , List , Optional
2
+ from typing import TYPE_CHECKING , List , Optional , Tuple
3
3
import warnings
4
4
5
5
from matplotlib .artist import Artist
45
45
46
46
if TYPE_CHECKING :
47
47
from matplotlib .axes import Axes
48
+ from matplotlib .axis import Axis
48
49
49
50
50
51
class MPLPlot :
@@ -68,16 +69,10 @@ def _kind(self):
68
69
_pop_attributes = [
69
70
"label" ,
70
71
"style" ,
71
- "logy" ,
72
- "logx" ,
73
- "loglog" ,
74
72
"mark_right" ,
75
73
"stacked" ,
76
74
]
77
75
_attr_defaults = {
78
- "logy" : False ,
79
- "logx" : False ,
80
- "loglog" : False ,
81
76
"mark_right" : True ,
82
77
"stacked" : False ,
83
78
}
@@ -167,6 +162,9 @@ def __init__(
167
162
self .legend_handles : List [Artist ] = []
168
163
self .legend_labels : List [Label ] = []
169
164
165
+ self .logx = kwds .pop ("logx" , False )
166
+ self .logy = kwds .pop ("logy" , False )
167
+ self .loglog = kwds .pop ("loglog" , False )
170
168
for attr in self ._pop_attributes :
171
169
value = kwds .pop (attr , self ._attr_defaults .get (attr , None ))
172
170
setattr (self , attr , value )
@@ -283,11 +281,11 @@ def generate(self):
283
281
def _args_adjust (self ):
284
282
pass
285
283
286
- def _has_plotted_object (self , ax ) :
284
+ def _has_plotted_object (self , ax : "Axes" ) -> bool :
287
285
"""check whether ax has data"""
288
286
return len (ax .lines ) != 0 or len (ax .artists ) != 0 or len (ax .containers ) != 0
289
287
290
- def _maybe_right_yaxis (self , ax , axes_num ):
288
+ def _maybe_right_yaxis (self , ax : "Axes" , axes_num ):
291
289
if not self .on_right (axes_num ):
292
290
# secondary axes may be passed via ax kw
293
291
return self ._get_ax_layer (ax )
@@ -523,7 +521,7 @@ def _adorn_subplots(self):
523
521
raise ValueError (msg )
524
522
self .axes [0 ].set_title (self .title )
525
523
526
- def _apply_axis_properties (self , axis , rot = None , fontsize = None ):
524
+ def _apply_axis_properties (self , axis : "Axis" , rot = None , fontsize = None ):
527
525
"""
528
526
Tick creation within matplotlib is reasonably expensive and is
529
527
internally deferred until accessed as Ticks are created/destroyed
@@ -540,7 +538,7 @@ def _apply_axis_properties(self, axis, rot=None, fontsize=None):
540
538
label .set_fontsize (fontsize )
541
539
542
540
@property
543
- def legend_title (self ):
541
+ def legend_title (self ) -> Optional [ str ] :
544
542
if not isinstance (self .data .columns , ABCMultiIndex ):
545
543
name = self .data .columns .name
546
544
if name is not None :
@@ -591,7 +589,7 @@ def _make_legend(self):
591
589
if ax .get_visible ():
592
590
ax .legend (loc = "best" )
593
591
594
- def _get_ax_legend_handle (self , ax ):
592
+ def _get_ax_legend_handle (self , ax : "Axes" ):
595
593
"""
596
594
Take in axes and return ax, legend and handle under different scenarios
597
595
"""
@@ -616,7 +614,7 @@ def plt(self):
616
614
617
615
_need_to_set_index = False
618
616
619
- def _get_xticks (self , convert_period = False ):
617
+ def _get_xticks (self , convert_period : bool = False ):
620
618
index = self .data .index
621
619
is_datetype = index .inferred_type in ("datetime" , "date" , "datetime64" , "time" )
622
620
@@ -646,7 +644,7 @@ def _get_xticks(self, convert_period=False):
646
644
647
645
@classmethod
648
646
@register_pandas_matplotlib_converters
649
- def _plot (cls , ax , x , y , style = None , is_errorbar = False , ** kwds ):
647
+ def _plot (cls , ax : "Axes" , x , y , style = None , is_errorbar : bool = False , ** kwds ):
650
648
mask = isna (y )
651
649
if mask .any ():
652
650
y = np .ma .array (y )
@@ -667,10 +665,10 @@ def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
667
665
if style is not None :
668
666
args = (x , y , style )
669
667
else :
670
- args = (x , y )
668
+ args = (x , y ) # type:ignore[assignment]
671
669
return ax .plot (* args , ** kwds )
672
670
673
- def _get_index_name (self ):
671
+ def _get_index_name (self ) -> Optional [ str ] :
674
672
if isinstance (self .data .index , ABCMultiIndex ):
675
673
name = self .data .index .names
676
674
if com .any_not_none (* name ):
@@ -877,7 +875,7 @@ def _get_subplots(self):
877
875
ax for ax in self .axes [0 ].get_figure ().get_axes () if isinstance (ax , Subplot )
878
876
]
879
877
880
- def _get_axes_layout (self ):
878
+ def _get_axes_layout (self ) -> Tuple [ int , int ] :
881
879
axes = self ._get_subplots ()
882
880
x_set = set ()
883
881
y_set = set ()
@@ -916,15 +914,15 @@ def __init__(self, data, x, y, **kwargs):
916
914
self .y = y
917
915
918
916
@property
919
- def nseries (self ):
917
+ def nseries (self ) -> int :
920
918
return 1
921
919
922
- def _post_plot_logic (self , ax , data ):
920
+ def _post_plot_logic (self , ax : "Axes" , data ):
923
921
x , y = self .x , self .y
924
922
ax .set_ylabel (pprint_thing (y ))
925
923
ax .set_xlabel (pprint_thing (x ))
926
924
927
- def _plot_colorbar (self , ax , ** kwds ):
925
+ def _plot_colorbar (self , ax : "Axes" , ** kwds ):
928
926
# Addresses issues #10611 and #10678:
929
927
# When plotting scatterplots and hexbinplots in IPython
930
928
# inline backend the colorbar axis height tends not to
@@ -1080,7 +1078,7 @@ def __init__(self, data, **kwargs):
1080
1078
if "x_compat" in self .kwds :
1081
1079
self .x_compat = bool (self .kwds .pop ("x_compat" ))
1082
1080
1083
- def _is_ts_plot (self ):
1081
+ def _is_ts_plot (self ) -> bool :
1084
1082
# this is slightly deceptive
1085
1083
return not self .x_compat and self .use_index and self ._use_dynamic_x ()
1086
1084
@@ -1139,7 +1137,9 @@ def _make_plot(self):
1139
1137
ax .set_xlim (left , right )
1140
1138
1141
1139
@classmethod
1142
- def _plot (cls , ax , x , y , style = None , column_num = None , stacking_id = None , ** kwds ):
1140
+ def _plot (
1141
+ cls , ax : "Axes" , x , y , style = None , column_num = None , stacking_id = None , ** kwds
1142
+ ):
1143
1143
# column_num is used to get the target column from plotf in line and
1144
1144
# area plots
1145
1145
if column_num == 0 :
@@ -1183,7 +1183,7 @@ def _get_stacking_id(self):
1183
1183
return None
1184
1184
1185
1185
@classmethod
1186
- def _initialize_stacker (cls , ax , stacking_id , n ):
1186
+ def _initialize_stacker (cls , ax : "Axes" , stacking_id , n : int ):
1187
1187
if stacking_id is None :
1188
1188
return
1189
1189
if not hasattr (ax , "_stacker_pos_prior" ):
@@ -1194,7 +1194,7 @@ def _initialize_stacker(cls, ax, stacking_id, n):
1194
1194
ax ._stacker_neg_prior [stacking_id ] = np .zeros (n )
1195
1195
1196
1196
@classmethod
1197
- def _get_stacked_values (cls , ax , stacking_id , values , label ):
1197
+ def _get_stacked_values (cls , ax : "Axes" , stacking_id , values , label ):
1198
1198
if stacking_id is None :
1199
1199
return values
1200
1200
if not hasattr (ax , "_stacker_pos_prior" ):
@@ -1213,15 +1213,15 @@ def _get_stacked_values(cls, ax, stacking_id, values, label):
1213
1213
)
1214
1214
1215
1215
@classmethod
1216
- def _update_stacker (cls , ax , stacking_id , values ):
1216
+ def _update_stacker (cls , ax : "Axes" , stacking_id , values ):
1217
1217
if stacking_id is None :
1218
1218
return
1219
1219
if (values >= 0 ).all ():
1220
1220
ax ._stacker_pos_prior [stacking_id ] += values
1221
1221
elif (values <= 0 ).all ():
1222
1222
ax ._stacker_neg_prior [stacking_id ] += values
1223
1223
1224
- def _post_plot_logic (self , ax , data ):
1224
+ def _post_plot_logic (self , ax : "Axes" , data ):
1225
1225
from matplotlib .ticker import FixedLocator
1226
1226
1227
1227
def get_label (i ):
@@ -1276,7 +1276,7 @@ def __init__(self, data, **kwargs):
1276
1276
@classmethod
1277
1277
def _plot (
1278
1278
cls ,
1279
- ax ,
1279
+ ax : "Axes" ,
1280
1280
x ,
1281
1281
y ,
1282
1282
style = None ,
@@ -1318,7 +1318,7 @@ def _plot(
1318
1318
res = [rect ]
1319
1319
return res
1320
1320
1321
- def _post_plot_logic (self , ax , data ):
1321
+ def _post_plot_logic (self , ax : "Axes" , data ):
1322
1322
LinePlot ._post_plot_logic (self , ax , data )
1323
1323
1324
1324
if self .ylim is None :
@@ -1372,7 +1372,7 @@ def _args_adjust(self):
1372
1372
self .left = np .array (self .left )
1373
1373
1374
1374
@classmethod
1375
- def _plot (cls , ax , x , y , w , start = 0 , log = False , ** kwds ):
1375
+ def _plot (cls , ax : "Axes" , x , y , w , start = 0 , log = False , ** kwds ):
1376
1376
return ax .bar (x , y , w , bottom = start , log = log , ** kwds )
1377
1377
1378
1378
@property
@@ -1454,7 +1454,7 @@ def _make_plot(self):
1454
1454
)
1455
1455
self ._add_legend_handle (rect , label , index = i )
1456
1456
1457
- def _post_plot_logic (self , ax , data ):
1457
+ def _post_plot_logic (self , ax : "Axes" , data ):
1458
1458
if self .use_index :
1459
1459
str_index = [pprint_thing (key ) for key in data .index ]
1460
1460
else :
@@ -1466,7 +1466,7 @@ def _post_plot_logic(self, ax, data):
1466
1466
1467
1467
self ._decorate_ticks (ax , name , str_index , s_edge , e_edge )
1468
1468
1469
- def _decorate_ticks (self , ax , name , ticklabels , start_edge , end_edge ):
1469
+ def _decorate_ticks (self , ax : "Axes" , name , ticklabels , start_edge , end_edge ):
1470
1470
ax .set_xlim ((start_edge , end_edge ))
1471
1471
1472
1472
if self .xticks is not None :
@@ -1489,10 +1489,10 @@ def _start_base(self):
1489
1489
return self .left
1490
1490
1491
1491
@classmethod
1492
- def _plot (cls , ax , x , y , w , start = 0 , log = False , ** kwds ):
1492
+ def _plot (cls , ax : "Axes" , x , y , w , start = 0 , log = False , ** kwds ):
1493
1493
return ax .barh (x , y , w , left = start , log = log , ** kwds )
1494
1494
1495
- def _decorate_ticks (self , ax , name , ticklabels , start_edge , end_edge ):
1495
+ def _decorate_ticks (self , ax : "Axes" , name , ticklabels , start_edge , end_edge ):
1496
1496
# horizontal bars
1497
1497
ax .set_ylim ((start_edge , end_edge ))
1498
1498
ax .set_yticks (self .tick_pos )
0 commit comments