1
1
import random
2
+ from typing import TYPE_CHECKING , Dict , List , Optional , Set
2
3
3
4
import matplotlib .lines as mlines
4
5
import matplotlib .patches as patches
5
6
import numpy as np
6
7
8
+ from pandas ._typing import Label
9
+
7
10
from pandas .core .dtypes .missing import notna
8
11
9
12
from pandas .io .formats .printing import pprint_thing
10
13
from pandas .plotting ._matplotlib .style import _get_standard_colors
11
14
from pandas .plotting ._matplotlib .tools import _set_ticks_props , _subplots
12
15
16
+ if TYPE_CHECKING :
17
+ from matplotlib .axes import Axes
18
+ from matplotlib .figure import Figure
19
+
20
+ from pandas import DataFrame , Series
21
+
13
22
14
23
def scatter_matrix (
15
- frame ,
24
+ frame : "DataFrame" ,
16
25
alpha = 0.5 ,
17
26
figsize = None ,
18
27
ax = None ,
@@ -114,7 +123,14 @@ def _get_marker_compat(marker):
114
123
return marker
115
124
116
125
117
- def radviz (frame , class_column , ax = None , color = None , colormap = None , ** kwds ):
126
+ def radviz (
127
+ frame : "DataFrame" ,
128
+ class_column ,
129
+ ax : Optional ["Axes" ] = None ,
130
+ color = None ,
131
+ colormap = None ,
132
+ ** kwds ,
133
+ ) -> "Axes" :
118
134
import matplotlib .pyplot as plt
119
135
120
136
def normalize (series ):
@@ -130,7 +146,7 @@ def normalize(series):
130
146
if ax is None :
131
147
ax = plt .gca (xlim = [- 1 , 1 ], ylim = [- 1 , 1 ])
132
148
133
- to_plot = {}
149
+ to_plot : Dict [ Label , List [ List ]] = {}
134
150
colors = _get_standard_colors (
135
151
num_colors = len (classes ), colormap = colormap , color_type = "random" , color = color
136
152
)
@@ -197,8 +213,14 @@ def normalize(series):
197
213
198
214
199
215
def andrews_curves (
200
- frame , class_column , ax = None , samples = 200 , color = None , colormap = None , ** kwds
201
- ):
216
+ frame : "DataFrame" ,
217
+ class_column ,
218
+ ax : Optional ["Axes" ] = None ,
219
+ samples : int = 200 ,
220
+ color = None ,
221
+ colormap = None ,
222
+ ** kwds ,
223
+ ) -> "Axes" :
202
224
import matplotlib .pyplot as plt
203
225
204
226
def function (amplitudes ):
@@ -231,7 +253,7 @@ def f(t):
231
253
classes = frame [class_column ].drop_duplicates ()
232
254
df = frame .drop (class_column , axis = 1 )
233
255
t = np .linspace (- np .pi , np .pi , samples )
234
- used_legends = set ()
256
+ used_legends : Set [ str ] = set ()
235
257
236
258
color_values = _get_standard_colors (
237
259
num_colors = len (classes ), colormap = colormap , color_type = "random" , color = color
@@ -256,7 +278,13 @@ def f(t):
256
278
return ax
257
279
258
280
259
- def bootstrap_plot (series , fig = None , size = 50 , samples = 500 , ** kwds ):
281
+ def bootstrap_plot (
282
+ series : "Series" ,
283
+ fig : Optional ["Figure" ] = None ,
284
+ size : int = 50 ,
285
+ samples : int = 500 ,
286
+ ** kwds ,
287
+ ) -> "Figure" :
260
288
261
289
import matplotlib .pyplot as plt
262
290
@@ -306,19 +334,19 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
306
334
307
335
308
336
def parallel_coordinates (
309
- frame ,
337
+ frame : "DataFrame" ,
310
338
class_column ,
311
339
cols = None ,
312
- ax = None ,
340
+ ax : Optional [ "Axes" ] = None ,
313
341
color = None ,
314
342
use_columns = False ,
315
343
xticks = None ,
316
344
colormap = None ,
317
- axvlines = True ,
345
+ axvlines : bool = True ,
318
346
axvlines_kwds = None ,
319
- sort_labels = False ,
347
+ sort_labels : bool = False ,
320
348
** kwds ,
321
- ):
349
+ ) -> "Axes" :
322
350
import matplotlib .pyplot as plt
323
351
324
352
if axvlines_kwds is None :
@@ -333,7 +361,7 @@ def parallel_coordinates(
333
361
else :
334
362
df = frame [cols ]
335
363
336
- used_legends = set ()
364
+ used_legends : Set [ str ] = set ()
337
365
338
366
ncols = len (df .columns )
339
367
@@ -385,7 +413,9 @@ def parallel_coordinates(
385
413
return ax
386
414
387
415
388
- def lag_plot (series , lag = 1 , ax = None , ** kwds ):
416
+ def lag_plot (
417
+ series : "Series" , lag : int = 1 , ax : Optional ["Axes" ] = None , ** kwds
418
+ ) -> "Axes" :
389
419
# workaround because `c='b'` is hardcoded in matplotlib's scatter method
390
420
import matplotlib .pyplot as plt
391
421
@@ -402,7 +432,9 @@ def lag_plot(series, lag=1, ax=None, **kwds):
402
432
return ax
403
433
404
434
405
- def autocorrelation_plot (series , ax = None , ** kwds ):
435
+ def autocorrelation_plot (
436
+ series : "Series" , ax : Optional ["Axes" ] = None , ** kwds
437
+ ) -> "Axes" :
406
438
import matplotlib .pyplot as plt
407
439
408
440
n = len (series )
0 commit comments