2
2
3
3
from typing import (
4
4
TYPE_CHECKING ,
5
+ Any ,
5
6
Literal ,
7
+ final ,
6
8
)
7
9
8
10
import numpy as np
@@ -58,13 +60,15 @@ def __init__(
58
60
bottom : int | np .ndarray = 0 ,
59
61
* ,
60
62
range = None ,
63
+ weights = None ,
61
64
** kwargs ,
62
65
) -> None :
63
66
if is_list_like (bottom ):
64
67
bottom = np .array (bottom )
65
68
self .bottom = bottom
66
69
67
70
self ._bin_range = range
71
+ self .weights = weights
68
72
69
73
self .xlabel = kwargs .get ("xlabel" )
70
74
self .ylabel = kwargs .get ("ylabel" )
@@ -96,7 +100,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
96
100
@classmethod
97
101
def _plot ( # type: ignore[override]
98
102
cls ,
99
- ax ,
103
+ ax : Axes ,
100
104
y ,
101
105
style = None ,
102
106
bottom : int | np .ndarray = 0 ,
@@ -140,7 +144,7 @@ def _make_plot(self, fig: Figure) -> None:
140
144
if style is not None :
141
145
kwds ["style" ] = style
142
146
143
- kwds = self ._make_plot_keywords (kwds , y )
147
+ self ._make_plot_keywords (kwds , y )
144
148
145
149
# the bins is multi-dimension array now and each plot need only 1-d and
146
150
# when by is applied, label should be columns that are grouped
@@ -149,21 +153,8 @@ def _make_plot(self, fig: Figure) -> None:
149
153
kwds ["label" ] = self .columns
150
154
kwds .pop ("color" )
151
155
152
- # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
153
- # and each sub-array (10,) will be called in each iteration. If users only
154
- # provide 1D array, we assume the same weights is used for all iterations
155
- weights = kwds .get ("weights" , None )
156
- if weights is not None :
157
- if np .ndim (weights ) != 1 and np .shape (weights )[- 1 ] != 1 :
158
- try :
159
- weights = weights [:, i ]
160
- except IndexError as err :
161
- raise ValueError (
162
- "weights must have the same shape as data, "
163
- "or be a single column"
164
- ) from err
165
- weights = weights [~ isna (y )]
166
- kwds ["weights" ] = weights
156
+ if self .weights is not None :
157
+ kwds ["weights" ] = self ._get_column_weights (self .weights , i , y )
167
158
168
159
y = reformat_hist_y_given_by (y , self .by )
169
160
@@ -175,12 +166,29 @@ def _make_plot(self, fig: Figure) -> None:
175
166
176
167
self ._append_legend_handles_labels (artists [0 ], label )
177
168
178
- def _make_plot_keywords (self , kwds , y ) :
169
+ def _make_plot_keywords (self , kwds : dict [ str , Any ], y ) -> None :
179
170
"""merge BoxPlot/KdePlot properties to passed kwds"""
180
171
# y is required for KdePlot
181
172
kwds ["bottom" ] = self .bottom
182
173
kwds ["bins" ] = self .bins
183
- return kwds
174
+
175
+ @final
176
+ @staticmethod
177
+ def _get_column_weights (weights , i : int , y ):
178
+ # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
179
+ # and each sub-array (10,) will be called in each iteration. If users only
180
+ # provide 1D array, we assume the same weights is used for all iterations
181
+ if weights is not None :
182
+ if np .ndim (weights ) != 1 and np .shape (weights )[- 1 ] != 1 :
183
+ try :
184
+ weights = weights [:, i ]
185
+ except IndexError as err :
186
+ raise ValueError (
187
+ "weights must have the same shape as data, "
188
+ "or be a single column"
189
+ ) from err
190
+ weights = weights [~ isna (y )]
191
+ return weights
184
192
185
193
def _post_plot_logic (self , ax : Axes , data ) -> None :
186
194
if self .orientation == "horizontal" :
@@ -207,11 +215,14 @@ def _kind(self) -> Literal["kde"]:
207
215
def orientation (self ) -> Literal ["vertical" ]:
208
216
return "vertical"
209
217
210
- def __init__ (self , data , bw_method = None , ind = None , ** kwargs ) -> None :
218
+ def __init__ (
219
+ self , data , bw_method = None , ind = None , * , weights = None , ** kwargs
220
+ ) -> None :
211
221
# Do not call LinePlot.__init__ which may fill nan
212
222
MPLPlot .__init__ (self , data , ** kwargs ) # pylint: disable=non-parent-init-called
213
223
self .bw_method = bw_method
214
224
self .ind = ind
225
+ self .weights = weights
215
226
216
227
@staticmethod
217
228
def _get_ind (y , ind ):
@@ -233,9 +244,10 @@ def _get_ind(y, ind):
233
244
return ind
234
245
235
246
@classmethod
236
- def _plot (
247
+ # error: Signature of "_plot" incompatible with supertype "MPLPlot"
248
+ def _plot ( # type: ignore[override]
237
249
cls ,
238
- ax ,
250
+ ax : Axes ,
239
251
y ,
240
252
style = None ,
241
253
bw_method = None ,
@@ -253,10 +265,9 @@ def _plot(
253
265
lines = MPLPlot ._plot (ax , ind , y , style = style , ** kwds )
254
266
return lines
255
267
256
- def _make_plot_keywords (self , kwds , y ) :
268
+ def _make_plot_keywords (self , kwds : dict [ str , Any ], y ) -> None :
257
269
kwds ["bw_method" ] = self .bw_method
258
270
kwds ["ind" ] = self ._get_ind (y , ind = self .ind )
259
- return kwds
260
271
261
272
def _post_plot_logic (self , ax , data ) -> None :
262
273
ax .set_ylabel ("Density" )
0 commit comments