Skip to content

Commit 3bb135d

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
TYP: Annotate plotting stacker (pandas-dev#36016)
1 parent a89f4d9 commit 3bb135d

File tree

3 files changed

+45
-36
lines changed

3 files changed

+45
-36
lines changed

pandas/plotting/_matplotlib/boxplot.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import namedtuple
2+
from typing import TYPE_CHECKING
23
import warnings
34

45
from matplotlib.artist import setp
@@ -14,6 +15,9 @@
1415
from pandas.plotting._matplotlib.style import _get_standard_colors
1516
from pandas.plotting._matplotlib.tools import _flatten, _subplots
1617

18+
if TYPE_CHECKING:
19+
from matplotlib.axes import Axes
20+
1721

1822
class BoxPlot(LinePlot):
1923
_kind = "box"
@@ -150,7 +154,7 @@ def _make_plot(self):
150154
labels = [pprint_thing(key) for key in range(len(labels))]
151155
self._set_ticklabels(ax, labels)
152156

153-
def _set_ticklabels(self, ax, labels):
157+
def _set_ticklabels(self, ax: "Axes", labels):
154158
if self.orientation == "vertical":
155159
ax.set_xticklabels(labels)
156160
else:
@@ -292,7 +296,7 @@ def maybe_color_bp(bp, **kwds):
292296
if not kwds.get("capprops"):
293297
setp(bp["caps"], color=colors[3], alpha=1)
294298

295-
def plot_group(keys, values, ax):
299+
def plot_group(keys, values, ax: "Axes"):
296300
keys = [pprint_thing(x) for x in keys]
297301
values = [np.asarray(remove_na_arraylike(v), dtype=object) for v in values]
298302
bp = ax.boxplot(values, **kwds)

pandas/plotting/_matplotlib/core.py

+33-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import TYPE_CHECKING, List, Optional
2+
from typing import TYPE_CHECKING, List, Optional, Tuple
33
import warnings
44

55
from matplotlib.artist import Artist
@@ -45,6 +45,7 @@
4545

4646
if TYPE_CHECKING:
4747
from matplotlib.axes import Axes
48+
from matplotlib.axis import Axis
4849

4950

5051
class MPLPlot:
@@ -68,16 +69,10 @@ def _kind(self):
6869
_pop_attributes = [
6970
"label",
7071
"style",
71-
"logy",
72-
"logx",
73-
"loglog",
7472
"mark_right",
7573
"stacked",
7674
]
7775
_attr_defaults = {
78-
"logy": False,
79-
"logx": False,
80-
"loglog": False,
8176
"mark_right": True,
8277
"stacked": False,
8378
}
@@ -167,6 +162,9 @@ def __init__(
167162
self.legend_handles: List[Artist] = []
168163
self.legend_labels: List[Label] = []
169164

165+
self.logx = kwds.pop("logx", False)
166+
self.logy = kwds.pop("logy", False)
167+
self.loglog = kwds.pop("loglog", False)
170168
for attr in self._pop_attributes:
171169
value = kwds.pop(attr, self._attr_defaults.get(attr, None))
172170
setattr(self, attr, value)
@@ -283,11 +281,11 @@ def generate(self):
283281
def _args_adjust(self):
284282
pass
285283

286-
def _has_plotted_object(self, ax):
284+
def _has_plotted_object(self, ax: "Axes") -> bool:
287285
"""check whether ax has data"""
288286
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0
289287

290-
def _maybe_right_yaxis(self, ax, axes_num):
288+
def _maybe_right_yaxis(self, ax: "Axes", axes_num):
291289
if not self.on_right(axes_num):
292290
# secondary axes may be passed via ax kw
293291
return self._get_ax_layer(ax)
@@ -523,7 +521,7 @@ def _adorn_subplots(self):
523521
raise ValueError(msg)
524522
self.axes[0].set_title(self.title)
525523

526-
def _apply_axis_properties(self, axis, rot=None, fontsize=None):
524+
def _apply_axis_properties(self, axis: "Axis", rot=None, fontsize=None):
527525
"""
528526
Tick creation within matplotlib is reasonably expensive and is
529527
internally deferred until accessed as Ticks are created/destroyed
@@ -540,7 +538,7 @@ def _apply_axis_properties(self, axis, rot=None, fontsize=None):
540538
label.set_fontsize(fontsize)
541539

542540
@property
543-
def legend_title(self):
541+
def legend_title(self) -> Optional[str]:
544542
if not isinstance(self.data.columns, ABCMultiIndex):
545543
name = self.data.columns.name
546544
if name is not None:
@@ -591,7 +589,7 @@ def _make_legend(self):
591589
if ax.get_visible():
592590
ax.legend(loc="best")
593591

594-
def _get_ax_legend_handle(self, ax):
592+
def _get_ax_legend_handle(self, ax: "Axes"):
595593
"""
596594
Take in axes and return ax, legend and handle under different scenarios
597595
"""
@@ -616,7 +614,7 @@ def plt(self):
616614

617615
_need_to_set_index = False
618616

619-
def _get_xticks(self, convert_period=False):
617+
def _get_xticks(self, convert_period: bool = False):
620618
index = self.data.index
621619
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
622620

@@ -646,7 +644,7 @@ def _get_xticks(self, convert_period=False):
646644

647645
@classmethod
648646
@register_pandas_matplotlib_converters
649-
def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
647+
def _plot(cls, ax: "Axes", x, y, style=None, is_errorbar: bool = False, **kwds):
650648
mask = isna(y)
651649
if mask.any():
652650
y = np.ma.array(y)
@@ -667,10 +665,10 @@ def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
667665
if style is not None:
668666
args = (x, y, style)
669667
else:
670-
args = (x, y)
668+
args = (x, y) # type:ignore[assignment]
671669
return ax.plot(*args, **kwds)
672670

673-
def _get_index_name(self):
671+
def _get_index_name(self) -> Optional[str]:
674672
if isinstance(self.data.index, ABCMultiIndex):
675673
name = self.data.index.names
676674
if com.any_not_none(*name):
@@ -877,7 +875,7 @@ def _get_subplots(self):
877875
ax for ax in self.axes[0].get_figure().get_axes() if isinstance(ax, Subplot)
878876
]
879877

880-
def _get_axes_layout(self):
878+
def _get_axes_layout(self) -> Tuple[int, int]:
881879
axes = self._get_subplots()
882880
x_set = set()
883881
y_set = set()
@@ -916,15 +914,15 @@ def __init__(self, data, x, y, **kwargs):
916914
self.y = y
917915

918916
@property
919-
def nseries(self):
917+
def nseries(self) -> int:
920918
return 1
921919

922-
def _post_plot_logic(self, ax, data):
920+
def _post_plot_logic(self, ax: "Axes", data):
923921
x, y = self.x, self.y
924922
ax.set_ylabel(pprint_thing(y))
925923
ax.set_xlabel(pprint_thing(x))
926924

927-
def _plot_colorbar(self, ax, **kwds):
925+
def _plot_colorbar(self, ax: "Axes", **kwds):
928926
# Addresses issues #10611 and #10678:
929927
# When plotting scatterplots and hexbinplots in IPython
930928
# inline backend the colorbar axis height tends not to
@@ -1080,7 +1078,7 @@ def __init__(self, data, **kwargs):
10801078
if "x_compat" in self.kwds:
10811079
self.x_compat = bool(self.kwds.pop("x_compat"))
10821080

1083-
def _is_ts_plot(self):
1081+
def _is_ts_plot(self) -> bool:
10841082
# this is slightly deceptive
10851083
return not self.x_compat and self.use_index and self._use_dynamic_x()
10861084

@@ -1139,7 +1137,9 @@ def _make_plot(self):
11391137
ax.set_xlim(left, right)
11401138

11411139
@classmethod
1142-
def _plot(cls, ax, x, y, style=None, column_num=None, stacking_id=None, **kwds):
1140+
def _plot(
1141+
cls, ax: "Axes", x, y, style=None, column_num=None, stacking_id=None, **kwds
1142+
):
11431143
# column_num is used to get the target column from plotf in line and
11441144
# area plots
11451145
if column_num == 0:
@@ -1183,7 +1183,7 @@ def _get_stacking_id(self):
11831183
return None
11841184

11851185
@classmethod
1186-
def _initialize_stacker(cls, ax, stacking_id, n):
1186+
def _initialize_stacker(cls, ax: "Axes", stacking_id, n: int):
11871187
if stacking_id is None:
11881188
return
11891189
if not hasattr(ax, "_stacker_pos_prior"):
@@ -1194,7 +1194,7 @@ def _initialize_stacker(cls, ax, stacking_id, n):
11941194
ax._stacker_neg_prior[stacking_id] = np.zeros(n)
11951195

11961196
@classmethod
1197-
def _get_stacked_values(cls, ax, stacking_id, values, label):
1197+
def _get_stacked_values(cls, ax: "Axes", stacking_id, values, label):
11981198
if stacking_id is None:
11991199
return values
12001200
if not hasattr(ax, "_stacker_pos_prior"):
@@ -1213,15 +1213,15 @@ def _get_stacked_values(cls, ax, stacking_id, values, label):
12131213
)
12141214

12151215
@classmethod
1216-
def _update_stacker(cls, ax, stacking_id, values):
1216+
def _update_stacker(cls, ax: "Axes", stacking_id, values):
12171217
if stacking_id is None:
12181218
return
12191219
if (values >= 0).all():
12201220
ax._stacker_pos_prior[stacking_id] += values
12211221
elif (values <= 0).all():
12221222
ax._stacker_neg_prior[stacking_id] += values
12231223

1224-
def _post_plot_logic(self, ax, data):
1224+
def _post_plot_logic(self, ax: "Axes", data):
12251225
from matplotlib.ticker import FixedLocator
12261226

12271227
def get_label(i):
@@ -1276,7 +1276,7 @@ def __init__(self, data, **kwargs):
12761276
@classmethod
12771277
def _plot(
12781278
cls,
1279-
ax,
1279+
ax: "Axes",
12801280
x,
12811281
y,
12821282
style=None,
@@ -1318,7 +1318,7 @@ def _plot(
13181318
res = [rect]
13191319
return res
13201320

1321-
def _post_plot_logic(self, ax, data):
1321+
def _post_plot_logic(self, ax: "Axes", data):
13221322
LinePlot._post_plot_logic(self, ax, data)
13231323

13241324
if self.ylim is None:
@@ -1372,7 +1372,7 @@ def _args_adjust(self):
13721372
self.left = np.array(self.left)
13731373

13741374
@classmethod
1375-
def _plot(cls, ax, x, y, w, start=0, log=False, **kwds):
1375+
def _plot(cls, ax: "Axes", x, y, w, start=0, log=False, **kwds):
13761376
return ax.bar(x, y, w, bottom=start, log=log, **kwds)
13771377

13781378
@property
@@ -1454,7 +1454,7 @@ def _make_plot(self):
14541454
)
14551455
self._add_legend_handle(rect, label, index=i)
14561456

1457-
def _post_plot_logic(self, ax, data):
1457+
def _post_plot_logic(self, ax: "Axes", data):
14581458
if self.use_index:
14591459
str_index = [pprint_thing(key) for key in data.index]
14601460
else:
@@ -1466,7 +1466,7 @@ def _post_plot_logic(self, ax, data):
14661466

14671467
self._decorate_ticks(ax, name, str_index, s_edge, e_edge)
14681468

1469-
def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
1469+
def _decorate_ticks(self, ax: "Axes", name, ticklabels, start_edge, end_edge):
14701470
ax.set_xlim((start_edge, end_edge))
14711471

14721472
if self.xticks is not None:
@@ -1489,10 +1489,10 @@ def _start_base(self):
14891489
return self.left
14901490

14911491
@classmethod
1492-
def _plot(cls, ax, x, y, w, start=0, log=False, **kwds):
1492+
def _plot(cls, ax: "Axes", x, y, w, start=0, log=False, **kwds):
14931493
return ax.barh(x, y, w, left=start, log=log, **kwds)
14941494

1495-
def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
1495+
def _decorate_ticks(self, ax: "Axes", name, ticklabels, start_edge, end_edge):
14961496
# horizontal bars
14971497
ax.set_ylim((start_edge, end_edge))
14981498
ax.set_yticks(self.tick_pos)

pandas/plotting/_matplotlib/hist.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TYPE_CHECKING
2+
13
import numpy as np
24

35
from pandas.core.dtypes.common import is_integer, is_list_like
@@ -8,6 +10,9 @@
810
from pandas.plotting._matplotlib.core import LinePlot, MPLPlot
911
from pandas.plotting._matplotlib.tools import _flatten, _set_ticks_props, _subplots
1012

13+
if TYPE_CHECKING:
14+
from matplotlib.axes import Axes
15+
1116

1217
class HistPlot(LinePlot):
1318
_kind = "hist"
@@ -90,7 +95,7 @@ def _make_plot_keywords(self, kwds, y):
9095
kwds["bins"] = self.bins
9196
return kwds
9297

93-
def _post_plot_logic(self, ax, data):
98+
def _post_plot_logic(self, ax: "Axes", data):
9499
if self.orientation == "horizontal":
95100
ax.set_xlabel("Frequency")
96101
else:

0 commit comments

Comments
 (0)