Skip to content

Commit 303e40b

Browse files
authored
TYP: annotate plotting._matplotlib.misc (#36017)
1 parent 329e1c7 commit 303e40b

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

pandas/plotting/_matplotlib/misc.py

+47-15
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
import random
2+
from typing import TYPE_CHECKING, Dict, List, Optional, Set
23

34
import matplotlib.lines as mlines
45
import matplotlib.patches as patches
56
import numpy as np
67

8+
from pandas._typing import Label
9+
710
from pandas.core.dtypes.missing import notna
811

912
from pandas.io.formats.printing import pprint_thing
1013
from pandas.plotting._matplotlib.style import _get_standard_colors
1114
from pandas.plotting._matplotlib.tools import _set_ticks_props, _subplots
1215

16+
if TYPE_CHECKING:
17+
from matplotlib.axes import Axes
18+
from matplotlib.figure import Figure
19+
20+
from pandas import DataFrame, Series
21+
1322

1423
def scatter_matrix(
15-
frame,
24+
frame: "DataFrame",
1625
alpha=0.5,
1726
figsize=None,
1827
ax=None,
@@ -114,7 +123,14 @@ def _get_marker_compat(marker):
114123
return marker
115124

116125

117-
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
126+
def radviz(
127+
frame: "DataFrame",
128+
class_column,
129+
ax: Optional["Axes"] = None,
130+
color=None,
131+
colormap=None,
132+
**kwds,
133+
) -> "Axes":
118134
import matplotlib.pyplot as plt
119135

120136
def normalize(series):
@@ -130,7 +146,7 @@ def normalize(series):
130146
if ax is None:
131147
ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
132148

133-
to_plot = {}
149+
to_plot: Dict[Label, List[List]] = {}
134150
colors = _get_standard_colors(
135151
num_colors=len(classes), colormap=colormap, color_type="random", color=color
136152
)
@@ -197,8 +213,14 @@ def normalize(series):
197213

198214

199215
def andrews_curves(
200-
frame, class_column, ax=None, samples=200, color=None, colormap=None, **kwds
201-
):
216+
frame: "DataFrame",
217+
class_column,
218+
ax: Optional["Axes"] = None,
219+
samples: int = 200,
220+
color=None,
221+
colormap=None,
222+
**kwds,
223+
) -> "Axes":
202224
import matplotlib.pyplot as plt
203225

204226
def function(amplitudes):
@@ -231,7 +253,7 @@ def f(t):
231253
classes = frame[class_column].drop_duplicates()
232254
df = frame.drop(class_column, axis=1)
233255
t = np.linspace(-np.pi, np.pi, samples)
234-
used_legends = set()
256+
used_legends: Set[str] = set()
235257

236258
color_values = _get_standard_colors(
237259
num_colors=len(classes), colormap=colormap, color_type="random", color=color
@@ -256,7 +278,13 @@ def f(t):
256278
return ax
257279

258280

259-
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
281+
def bootstrap_plot(
282+
series: "Series",
283+
fig: Optional["Figure"] = None,
284+
size: int = 50,
285+
samples: int = 500,
286+
**kwds,
287+
) -> "Figure":
260288

261289
import matplotlib.pyplot as plt
262290

@@ -306,19 +334,19 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
306334

307335

308336
def parallel_coordinates(
309-
frame,
337+
frame: "DataFrame",
310338
class_column,
311339
cols=None,
312-
ax=None,
340+
ax: Optional["Axes"] = None,
313341
color=None,
314342
use_columns=False,
315343
xticks=None,
316344
colormap=None,
317-
axvlines=True,
345+
axvlines: bool = True,
318346
axvlines_kwds=None,
319-
sort_labels=False,
347+
sort_labels: bool = False,
320348
**kwds,
321-
):
349+
) -> "Axes":
322350
import matplotlib.pyplot as plt
323351

324352
if axvlines_kwds is None:
@@ -333,7 +361,7 @@ def parallel_coordinates(
333361
else:
334362
df = frame[cols]
335363

336-
used_legends = set()
364+
used_legends: Set[str] = set()
337365

338366
ncols = len(df.columns)
339367

@@ -385,7 +413,9 @@ def parallel_coordinates(
385413
return ax
386414

387415

388-
def lag_plot(series, lag=1, ax=None, **kwds):
416+
def lag_plot(
417+
series: "Series", lag: int = 1, ax: Optional["Axes"] = None, **kwds
418+
) -> "Axes":
389419
# workaround because `c='b'` is hardcoded in matplotlib's scatter method
390420
import matplotlib.pyplot as plt
391421

@@ -402,7 +432,9 @@ def lag_plot(series, lag=1, ax=None, **kwds):
402432
return ax
403433

404434

405-
def autocorrelation_plot(series, ax=None, **kwds):
435+
def autocorrelation_plot(
436+
series: "Series", ax: Optional["Axes"] = None, **kwds
437+
) -> "Axes":
406438
import matplotlib.pyplot as plt
407439

408440
n = len(series)

pandas/plotting/_matplotlib/style.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
def _get_standard_colors(
14-
num_colors=None, colormap=None, color_type="default", color=None
14+
num_colors=None, colormap=None, color_type: str = "default", color=None
1515
):
1616
import matplotlib.pyplot as plt
1717

0 commit comments

Comments
 (0)