1
1
# being a bit too dynamic
2
2
from math import ceil
3
- from typing import TYPE_CHECKING , Tuple
3
+ from typing import TYPE_CHECKING , Iterable , List , Sequence , Tuple , Union
4
4
import warnings
5
5
6
6
import matplotlib .table
15
15
from pandas .plotting ._matplotlib import compat
16
16
17
17
if TYPE_CHECKING :
18
+ from matplotlib .axes import Axes
19
+ from matplotlib .axis import Axis
20
+ from matplotlib .lines import Line2D # noqa:F401
18
21
from matplotlib .table import Table
19
22
20
23
21
- def format_date_labels (ax , rot ):
24
+ def format_date_labels (ax : "Axes" , rot ):
22
25
# mini version of autofmt_xdate
23
26
for label in ax .get_xticklabels ():
24
27
label .set_ha ("right" )
@@ -278,7 +281,7 @@ def _subplots(
278
281
return fig , axes
279
282
280
283
281
- def _remove_labels_from_axis (axis ):
284
+ def _remove_labels_from_axis (axis : "Axis" ):
282
285
for t in axis .get_majorticklabels ():
283
286
t .set_visible (False )
284
287
@@ -294,7 +297,15 @@ def _remove_labels_from_axis(axis):
294
297
axis .get_label ().set_visible (False )
295
298
296
299
297
- def _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey ):
300
+ def _handle_shared_axes (
301
+ axarr : Iterable ["Axes" ],
302
+ nplots : int ,
303
+ naxes : int ,
304
+ nrows : int ,
305
+ ncols : int ,
306
+ sharex : bool ,
307
+ sharey : bool ,
308
+ ):
298
309
if nplots > 1 :
299
310
if compat ._mpl_ge_3_2_0 ():
300
311
row_num = lambda x : x .get_subplotspec ().rowspan .start
@@ -340,15 +351,21 @@ def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey):
340
351
_remove_labels_from_axis (ax .yaxis )
341
352
342
353
343
- def _flatten (axes ) :
354
+ def _flatten (axes : Union [ "Axes" , Sequence [ "Axes" ]]) -> Sequence [ "Axes" ] :
344
355
if not is_list_like (axes ):
345
356
return np .array ([axes ])
346
357
elif isinstance (axes , (np .ndarray , ABCIndexClass )):
347
358
return axes .ravel ()
348
359
return np .array (axes )
349
360
350
361
351
- def _set_ticks_props (axes , xlabelsize = None , xrot = None , ylabelsize = None , yrot = None ):
362
+ def _set_ticks_props (
363
+ axes : Union ["Axes" , Sequence ["Axes" ]],
364
+ xlabelsize = None ,
365
+ xrot = None ,
366
+ ylabelsize = None ,
367
+ yrot = None ,
368
+ ):
352
369
import matplotlib .pyplot as plt
353
370
354
371
for ax in _flatten (axes ):
@@ -363,7 +380,7 @@ def _set_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=Non
363
380
return axes
364
381
365
382
366
- def _get_all_lines (ax ) :
383
+ def _get_all_lines (ax : "Axes" ) -> List [ "Line2D" ] :
367
384
lines = ax .get_lines ()
368
385
369
386
if hasattr (ax , "right_ax" ):
@@ -375,7 +392,7 @@ def _get_all_lines(ax):
375
392
return lines
376
393
377
394
378
- def _get_xlim (lines ) -> Tuple [float , float ]:
395
+ def _get_xlim (lines : Iterable [ "Line2D" ] ) -> Tuple [float , float ]:
379
396
left , right = np .inf , - np .inf
380
397
for l in lines :
381
398
x = l .get_xdata (orig = False )
0 commit comments