|
2 | 2 | import plotly.io as pio
|
3 | 3 | from collections import namedtuple, OrderedDict
|
4 | 4 | from ._special_inputs import IdentityMap, Constant, Range
|
| 5 | +from .trendline_functions import ols, lowess, ma, ewm |
5 | 6 |
|
6 | 7 | from _plotly_utils.basevalidators import ColorscaleValidator
|
7 | 8 | from plotly.colors import qualitative, sequential
|
@@ -212,65 +213,6 @@ def make_mapping(args, variable):
|
212 | 213 | )
|
213 | 214 |
|
214 | 215 |
|
215 |
| -def lowess(options, x, y, x_label, y_label, non_missing): |
216 |
| - import statsmodels.api as sm |
217 |
| - |
218 |
| - frac = options.get("frac", 0.6666666) |
219 |
| - # missing ='drop' is the default value for lowess but not for OLS (None) |
220 |
| - # we force it here in case statsmodels change their defaults |
221 |
| - y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] |
222 |
| - hover_header = "<b>LOWESS trendline</b><br><br>" |
223 |
| - return y_out, hover_header, None |
224 |
| - |
225 |
| - |
226 |
| -def ma(options, x, y, x_label, y_label, non_missing): |
227 |
| - y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing] |
228 |
| - hover_header = "<b>Moving Average trendline</b><br><br>" |
229 |
| - return y_out, hover_header, None |
230 |
| - |
231 |
| - |
232 |
| -def ewm(options, x, y, x_label, y_label, non_missing): |
233 |
| - y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing] |
234 |
| - hover_header = "<b>EWM trendline</b><br><br>" |
235 |
| - return y_out, hover_header, None |
236 |
| - |
237 |
| - |
238 |
| -def ols(options, x, y, x_label, y_label, non_missing): |
239 |
| - import statsmodels.api as sm |
240 |
| - |
241 |
| - add_constant = options.get("add_constant", True) |
242 |
| - log_x = options.get("log_x", False) |
243 |
| - log_y = options.get("log_y", False) |
244 |
| - |
245 |
| - if log_y: |
246 |
| - y = np.log(y) |
247 |
| - if log_x: |
248 |
| - x = np.log(x) |
249 |
| - if add_constant: |
250 |
| - x = sm.add_constant(x) |
251 |
| - fit_results = sm.OLS(y, x, missing="drop").fit() |
252 |
| - y_out = fit_results.predict() |
253 |
| - if log_y: |
254 |
| - y_out = np.exp(y_out) |
255 |
| - hover_header = "<b>OLS trendline</b><br>" |
256 |
| - if len(fit_results.params) == 2: |
257 |
| - hover_header += "%s = %g * %s + %g<br>" % ( |
258 |
| - y_label, |
259 |
| - fit_results.params[1], |
260 |
| - x_label, |
261 |
| - fit_results.params[0], |
262 |
| - ) |
263 |
| - elif not add_constant: |
264 |
| - hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,) |
265 |
| - else: |
266 |
| - hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],) |
267 |
| - hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared |
268 |
| - return y_out, hover_header, fit_results |
269 |
| - |
270 |
| - |
271 |
| -trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
272 |
| - |
273 |
| - |
274 | 216 | def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
275 | 217 | """Populates a dict with arguments to update trace
|
276 | 218 |
|
@@ -344,6 +286,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
|
344 | 286 | if trace_spec.constructor == go.Histogram:
|
345 | 287 | mapping_labels["count"] = "%{x}"
|
346 | 288 | elif attr_name == "trendline":
|
| 289 | + trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols) |
347 | 290 | if (
|
348 | 291 | attr_value in trendline_functions
|
349 | 292 | and args["x"]
|
|
0 commit comments