Skip to content

Commit 696e9bd

Browse files
authored
TYP: plotting._matplotlib (#47311)
* TYP: plotting._matplotlib * somehow super causes issues * fix pickle issue: was accessing _kind on class * and the last plotting file * add timedelta
1 parent 15902bd commit 696e9bd

File tree

11 files changed

+113
-51
lines changed

11 files changed

+113
-51
lines changed

pandas/_typing.py

+3
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,6 @@ def closed(self) -> bool:
326326

327327
# quantile interpolation
328328
QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"]
329+
330+
# plotting
331+
PlottingOrientation = Literal["horizontal", "vertical"]

pandas/core/generic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5727,7 +5727,7 @@ def _check_inplace_setting(self, value) -> bool_t:
57275727
return True
57285728

57295729
@final
5730-
def _get_numeric_data(self):
5730+
def _get_numeric_data(self: NDFrameT) -> NDFrameT:
57315731
return self._constructor(self._mgr.get_numeric_data()).__finalize__(self)
57325732

57335733
@final
@@ -10954,7 +10954,8 @@ def mad(
1095410954

1095510955
data = self._get_numeric_data()
1095610956
if axis == 0:
10957-
demeaned = data - data.mean(axis=0)
10957+
# error: Unsupported operand types for - ("NDFrame" and "float")
10958+
demeaned = data - data.mean(axis=0) # type: ignore[operator]
1095810959
else:
1095910960
demeaned = data.sub(data.mean(axis=1), axis=0)
1096010961
return np.abs(demeaned).mean(axis=axis, skipna=skipna)

pandas/plotting/_core.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1879,11 +1879,11 @@ def _get_plot_backend(backend: str | None = None):
18791879
-----
18801880
Modifies `_backends` with imported backend as a side effect.
18811881
"""
1882-
backend = backend or get_option("plotting.backend")
1882+
backend_str: str = backend or get_option("plotting.backend")
18831883

1884-
if backend in _backends:
1885-
return _backends[backend]
1884+
if backend_str in _backends:
1885+
return _backends[backend_str]
18861886

1887-
module = _load_backend(backend)
1888-
_backends[backend] = module
1887+
module = _load_backend(backend_str)
1888+
_backends[backend_str] = module
18891889
return module

pandas/plotting/_matplotlib/boxplot.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import (
44
TYPE_CHECKING,
5+
Literal,
56
NamedTuple,
67
)
78
import warnings
@@ -34,7 +35,10 @@
3435

3536

3637
class BoxPlot(LinePlot):
37-
_kind = "box"
38+
@property
39+
def _kind(self) -> Literal["box"]:
40+
return "box"
41+
3842
_layout_type = "horizontal"
3943

4044
_valid_return_types = (None, "axes", "dict", "both")

pandas/plotting/_matplotlib/converter.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def _daily_finder(vmin, vmax, freq: BaseOffset):
574574
Period(ordinal=int(vmin), freq=freq),
575575
Period(ordinal=int(vmax), freq=freq),
576576
)
577+
assert isinstance(vmin, Period)
578+
assert isinstance(vmax, Period)
577579
span = vmax.ordinal - vmin.ordinal + 1
578580
dates_ = period_range(start=vmin, end=vmax, freq=freq)
579581
# Initialize the output
@@ -1073,7 +1075,9 @@ def __call__(self, x, pos=0) -> str:
10731075
fmt = self.formatdict.pop(x, "")
10741076
if isinstance(fmt, np.bytes_):
10751077
fmt = fmt.decode("utf-8")
1076-
return Period(ordinal=int(x), freq=self.freq).strftime(fmt)
1078+
period = Period(ordinal=int(x), freq=self.freq)
1079+
assert isinstance(period, Period)
1080+
return period.strftime(fmt)
10771081

10781082

10791083
class TimeSeries_TimedeltaFormatter(Formatter):

pandas/plotting/_matplotlib/core.py

+60-20
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
from __future__ import annotations
22

3+
from abc import (
4+
ABC,
5+
abstractmethod,
6+
)
37
from typing import (
48
TYPE_CHECKING,
59
Hashable,
610
Iterable,
11+
Literal,
712
Sequence,
813
)
914
import warnings
1015

1116
from matplotlib.artist import Artist
1217
import numpy as np
1318

14-
from pandas._typing import IndexLabel
19+
from pandas._typing import (
20+
IndexLabel,
21+
PlottingOrientation,
22+
)
1523
from pandas.errors import AbstractMethodError
1624
from pandas.util._decorators import cache_readonly
1725

@@ -78,7 +86,7 @@ def _color_in_style(style: str) -> bool:
7886
return not set(BASE_COLORS).isdisjoint(style)
7987

8088

81-
class MPLPlot:
89+
class MPLPlot(ABC):
8290
"""
8391
Base class for assembling a pandas plot using matplotlib
8492
@@ -89,13 +97,17 @@ class MPLPlot:
8997
"""
9098

9199
@property
92-
def _kind(self):
100+
@abstractmethod
101+
def _kind(self) -> str:
93102
"""Specify kind str. Must be overridden in child class"""
94103
raise NotImplementedError
95104

96105
_layout_type = "vertical"
97106
_default_rot = 0
98-
orientation: str | None = None
107+
108+
@property
109+
def orientation(self) -> str | None:
110+
return None
99111

100112
axes: np.ndarray # of Axes objects
101113

@@ -843,7 +855,9 @@ def _get_xticks(self, convert_period: bool = False):
843855

844856
@classmethod
845857
@register_pandas_matplotlib_converters
846-
def _plot(cls, ax: Axes, x, y, style=None, is_errorbar: bool = False, **kwds):
858+
def _plot(
859+
cls, ax: Axes, x, y: np.ndarray, style=None, is_errorbar: bool = False, **kwds
860+
):
847861
mask = isna(y)
848862
if mask.any():
849863
y = np.ma.array(y)
@@ -1101,7 +1115,7 @@ def _get_axes_layout(self) -> tuple[int, int]:
11011115
return (len(y_set), len(x_set))
11021116

11031117

1104-
class PlanePlot(MPLPlot):
1118+
class PlanePlot(MPLPlot, ABC):
11051119
"""
11061120
Abstract class for plotting on plane, currently scatter and hexbin.
11071121
"""
@@ -1159,7 +1173,9 @@ def _plot_colorbar(self, ax: Axes, **kwds):
11591173

11601174

11611175
class ScatterPlot(PlanePlot):
1162-
_kind = "scatter"
1176+
@property
1177+
def _kind(self) -> Literal["scatter"]:
1178+
return "scatter"
11631179

11641180
def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
11651181
if s is None:
@@ -1247,7 +1263,9 @@ def _make_plot(self):
12471263

12481264

12491265
class HexBinPlot(PlanePlot):
1250-
_kind = "hexbin"
1266+
@property
1267+
def _kind(self) -> Literal["hexbin"]:
1268+
return "hexbin"
12511269

12521270
def __init__(self, data, x, y, C=None, **kwargs) -> None:
12531271
super().__init__(data, x, y, **kwargs)
@@ -1277,9 +1295,15 @@ def _make_legend(self):
12771295

12781296

12791297
class LinePlot(MPLPlot):
1280-
_kind = "line"
12811298
_default_rot = 0
1282-
orientation = "vertical"
1299+
1300+
@property
1301+
def orientation(self) -> PlottingOrientation:
1302+
return "vertical"
1303+
1304+
@property
1305+
def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
1306+
return "line"
12831307

12841308
def __init__(self, data, **kwargs) -> None:
12851309
from pandas.plotting import plot_params
@@ -1363,8 +1387,7 @@ def _plot( # type: ignore[override]
13631387
cls._update_stacker(ax, stacking_id, y)
13641388
return lines
13651389

1366-
@classmethod
1367-
def _ts_plot(cls, ax: Axes, x, data, style=None, **kwds):
1390+
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
13681391
# accept x to be consistent with normal plot func,
13691392
# x is not passed to tsplot as it uses data.index as x coordinate
13701393
# column_num must be in kwds for stacking purpose
@@ -1377,9 +1400,9 @@ def _ts_plot(cls, ax: Axes, x, data, style=None, **kwds):
13771400
decorate_axes(ax.left_ax, freq, kwds)
13781401
if hasattr(ax, "right_ax"):
13791402
decorate_axes(ax.right_ax, freq, kwds)
1380-
ax._plot_data.append((data, cls._kind, kwds))
1403+
ax._plot_data.append((data, self._kind, kwds))
13811404

1382-
lines = cls._plot(ax, data.index, data.values, style=style, **kwds)
1405+
lines = self._plot(ax, data.index, data.values, style=style, **kwds)
13831406
# set date formatter, locators and rescale limits
13841407
format_dateaxis(ax, ax.freq, data.index)
13851408
return lines
@@ -1471,7 +1494,9 @@ def get_label(i):
14711494

14721495

14731496
class AreaPlot(LinePlot):
1474-
_kind = "area"
1497+
@property
1498+
def _kind(self) -> Literal["area"]:
1499+
return "area"
14751500

14761501
def __init__(self, data, **kwargs) -> None:
14771502
kwargs.setdefault("stacked", True)
@@ -1544,9 +1569,15 @@ def _post_plot_logic(self, ax: Axes, data):
15441569

15451570

15461571
class BarPlot(MPLPlot):
1547-
_kind = "bar"
1572+
@property
1573+
def _kind(self) -> Literal["bar", "barh"]:
1574+
return "bar"
1575+
15481576
_default_rot = 90
1549-
orientation = "vertical"
1577+
1578+
@property
1579+
def orientation(self) -> PlottingOrientation:
1580+
return "vertical"
15501581

15511582
def __init__(self, data, **kwargs) -> None:
15521583
# we have to treat a series differently than a
@@ -1698,9 +1729,15 @@ def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge):
16981729

16991730

17001731
class BarhPlot(BarPlot):
1701-
_kind = "barh"
1732+
@property
1733+
def _kind(self) -> Literal["barh"]:
1734+
return "barh"
1735+
17021736
_default_rot = 0
1703-
orientation = "horizontal"
1737+
1738+
@property
1739+
def orientation(self) -> Literal["horizontal"]:
1740+
return "horizontal"
17041741

17051742
@property
17061743
def _start_base(self):
@@ -1727,7 +1764,10 @@ def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge):
17271764

17281765

17291766
class PiePlot(MPLPlot):
1730-
_kind = "pie"
1767+
@property
1768+
def _kind(self) -> Literal["pie"]:
1769+
return "pie"
1770+
17311771
_layout_type = "horizontal"
17321772

17331773
def __init__(self, data, kind=None, **kwargs) -> None:

pandas/plotting/_matplotlib/hist.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import (
4+
TYPE_CHECKING,
5+
Literal,
6+
)
47

58
import numpy as np
69

10+
from pandas._typing import PlottingOrientation
11+
712
from pandas.core.dtypes.common import (
813
is_integer,
914
is_list_like,
@@ -40,7 +45,9 @@
4045

4146

4247
class HistPlot(LinePlot):
43-
_kind = "hist"
48+
@property
49+
def _kind(self) -> Literal["hist", "kde"]:
50+
return "hist"
4451

4552
def __init__(self, data, bins=10, bottom=0, **kwargs) -> None:
4653
self.bins = bins # use mpl default
@@ -64,8 +71,8 @@ def _args_adjust(self):
6471

6572
def _calculate_bins(self, data: DataFrame) -> np.ndarray:
6673
"""Calculate bins given data"""
67-
values = data._convert(datetime=True)._get_numeric_data()
68-
values = np.ravel(values)
74+
nd_values = data._convert(datetime=True)._get_numeric_data()
75+
values = np.ravel(nd_values)
6976
values = values[~isna(values)]
7077

7178
hist, bins = np.histogram(
@@ -159,16 +166,21 @@ def _post_plot_logic(self, ax: Axes, data):
159166
ax.set_ylabel("Frequency")
160167

161168
@property
162-
def orientation(self):
169+
def orientation(self) -> PlottingOrientation:
163170
if self.kwds.get("orientation", None) == "horizontal":
164171
return "horizontal"
165172
else:
166173
return "vertical"
167174

168175

169176
class KdePlot(HistPlot):
170-
_kind = "kde"
171-
orientation = "vertical"
177+
@property
178+
def _kind(self) -> Literal["kde"]:
179+
return "kde"
180+
181+
@property
182+
def orientation(self) -> Literal["vertical"]:
183+
return "vertical"
172184

173185
def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
174186
MPLPlot.__init__(self, data, **kwargs)

pandas/plotting/_matplotlib/style.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ def _get_colors_from_colormap(
143143
num_colors: int,
144144
) -> list[Color]:
145145
"""Get colors from colormap."""
146-
colormap = _get_cmap_instance(colormap)
147-
return [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
146+
cmap = _get_cmap_instance(colormap)
147+
return [cmap(num) for num in np.linspace(0, 1, num=num_colors)]
148148

149149

150150
def _get_cmap_instance(colormap: str | Colormap) -> Colormap:

pandas/plotting/_matplotlib/timeseries.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from datetime import timedelta
56
import functools
67
from typing import (
78
TYPE_CHECKING,
@@ -185,11 +186,10 @@ def _get_ax_freq(ax: Axes):
185186
return ax_freq
186187

187188

188-
def _get_period_alias(freq) -> str | None:
189+
def _get_period_alias(freq: timedelta | BaseOffset | str) -> str | None:
189190
freqstr = to_offset(freq).rule_code
190191

191-
freq = get_period_alias(freqstr)
192-
return freq
192+
return get_period_alias(freqstr)
193193

194194

195195
def _get_freq(ax: Axes, series: Series):
@@ -235,7 +235,9 @@ def use_dynamic_x(ax: Axes, data: DataFrame | Series) -> bool:
235235
x = data.index
236236
if base <= FreqGroup.FR_DAY.value:
237237
return x[:1].is_normalized
238-
return Period(x[0], freq_str).to_timestamp().tz_localize(x.tz) == x[0]
238+
period = Period(x[0], freq_str)
239+
assert isinstance(period, Period)
240+
return period.to_timestamp().tz_localize(x.tz) == x[0]
239241
return True
240242

241243

pandas/plotting/_matplotlib/tools.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ def table(
8383
return table
8484

8585

86-
def _get_layout(nplots: int, layout=None, layout_type: str = "box") -> tuple[int, int]:
86+
def _get_layout(
87+
nplots: int,
88+
layout: tuple[int, int] | None = None,
89+
layout_type: str = "box",
90+
) -> tuple[int, int]:
8791
if layout is not None:
8892
if not isinstance(layout, (tuple, list)) or len(layout) != 2:
8993
raise ValueError("Layout must be a tuple of (rows, columns)")

0 commit comments

Comments
 (0)