Skip to content

Commit a167f13

Browse files
authored
TYP: plotting, make weights kwd explicit (#55877)
* TYP: plotting * TYP: plotting * Make weights kwd explicit
1 parent 7ae6b8e commit a167f13

File tree

3 files changed

+40
-26
lines changed

3 files changed

+40
-26
lines changed

pandas/plotting/_matplotlib/boxplot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
9393
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
9494
@classmethod
9595
def _plot( # type: ignore[override]
96-
cls, ax, y, column_num=None, return_type: str = "axes", **kwds
96+
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
9797
):
9898
if y.ndim == 2:
9999
y = [remove_na_arraylike(v) for v in y]

pandas/plotting/_matplotlib/core.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from typing import (
1313
TYPE_CHECKING,
14+
Any,
1415
Literal,
1516
final,
1617
)
@@ -998,7 +999,9 @@ def on_right(self, i: int):
998999
return self.data.columns[i] in self.secondary_y
9991000

10001001
@final
1001-
def _apply_style_colors(self, colors, kwds, col_num, label: str):
1002+
def _apply_style_colors(
1003+
self, colors, kwds: dict[str, Any], col_num: int, label: str
1004+
):
10021005
"""
10031006
Manage style and color based on column number and its label.
10041007
Returns tuple of appropriate style and kwds which "color" may be added.

pandas/plotting/_matplotlib/hist.py

+35-24
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from typing import (
44
TYPE_CHECKING,
5+
Any,
56
Literal,
7+
final,
68
)
79

810
import numpy as np
@@ -58,13 +60,15 @@ def __init__(
5860
bottom: int | np.ndarray = 0,
5961
*,
6062
range=None,
63+
weights=None,
6164
**kwargs,
6265
) -> None:
6366
if is_list_like(bottom):
6467
bottom = np.array(bottom)
6568
self.bottom = bottom
6669

6770
self._bin_range = range
71+
self.weights = weights
6872

6973
self.xlabel = kwargs.get("xlabel")
7074
self.ylabel = kwargs.get("ylabel")
@@ -96,7 +100,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
96100
@classmethod
97101
def _plot( # type: ignore[override]
98102
cls,
99-
ax,
103+
ax: Axes,
100104
y,
101105
style=None,
102106
bottom: int | np.ndarray = 0,
@@ -140,7 +144,7 @@ def _make_plot(self, fig: Figure) -> None:
140144
if style is not None:
141145
kwds["style"] = style
142146

143-
kwds = self._make_plot_keywords(kwds, y)
147+
self._make_plot_keywords(kwds, y)
144148

145149
# the bins is multi-dimension array now and each plot need only 1-d and
146150
# when by is applied, label should be columns that are grouped
@@ -149,21 +153,8 @@ def _make_plot(self, fig: Figure) -> None:
149153
kwds["label"] = self.columns
150154
kwds.pop("color")
151155

152-
# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
153-
# and each sub-array (10,) will be called in each iteration. If users only
154-
# provide 1D array, we assume the same weights is used for all iterations
155-
weights = kwds.get("weights", None)
156-
if weights is not None:
157-
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
158-
try:
159-
weights = weights[:, i]
160-
except IndexError as err:
161-
raise ValueError(
162-
"weights must have the same shape as data, "
163-
"or be a single column"
164-
) from err
165-
weights = weights[~isna(y)]
166-
kwds["weights"] = weights
156+
if self.weights is not None:
157+
kwds["weights"] = self._get_column_weights(self.weights, i, y)
167158

168159
y = reformat_hist_y_given_by(y, self.by)
169160

@@ -175,12 +166,29 @@ def _make_plot(self, fig: Figure) -> None:
175166

176167
self._append_legend_handles_labels(artists[0], label)
177168

178-
def _make_plot_keywords(self, kwds, y):
169+
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
179170
"""merge BoxPlot/KdePlot properties to passed kwds"""
180171
# y is required for KdePlot
181172
kwds["bottom"] = self.bottom
182173
kwds["bins"] = self.bins
183-
return kwds
174+
175+
@final
176+
@staticmethod
177+
def _get_column_weights(weights, i: int, y):
178+
# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
179+
# and each sub-array (10,) will be called in each iteration. If users only
180+
# provide 1D array, we assume the same weights is used for all iterations
181+
if weights is not None:
182+
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
183+
try:
184+
weights = weights[:, i]
185+
except IndexError as err:
186+
raise ValueError(
187+
"weights must have the same shape as data, "
188+
"or be a single column"
189+
) from err
190+
weights = weights[~isna(y)]
191+
return weights
184192

185193
def _post_plot_logic(self, ax: Axes, data) -> None:
186194
if self.orientation == "horizontal":
@@ -207,11 +215,14 @@ def _kind(self) -> Literal["kde"]:
207215
def orientation(self) -> Literal["vertical"]:
208216
return "vertical"
209217

210-
def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
218+
def __init__(
219+
self, data, bw_method=None, ind=None, *, weights=None, **kwargs
220+
) -> None:
211221
# Do not call LinePlot.__init__ which may fill nan
212222
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
213223
self.bw_method = bw_method
214224
self.ind = ind
225+
self.weights = weights
215226

216227
@staticmethod
217228
def _get_ind(y, ind):
@@ -233,9 +244,10 @@ def _get_ind(y, ind):
233244
return ind
234245

235246
@classmethod
236-
def _plot(
247+
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
248+
def _plot( # type: ignore[override]
237249
cls,
238-
ax,
250+
ax: Axes,
239251
y,
240252
style=None,
241253
bw_method=None,
@@ -253,10 +265,9 @@ def _plot(
253265
lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
254266
return lines
255267

256-
def _make_plot_keywords(self, kwds, y):
268+
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
257269
kwds["bw_method"] = self.bw_method
258270
kwds["ind"] = self._get_ind(y, ind=self.ind)
259-
return kwds
260271

261272
def _post_plot_logic(self, ax, data) -> None:
262273
ax.set_ylabel("Density")

0 commit comments

Comments
 (0)