@@ -212,6 +212,56 @@ def make_mapping(args, variable):
212
212
)
213
213
214
214
215
+ def lowess (options , x , y , x_label , y_label , non_missing ):
216
+ import statsmodels .api as sm
217
+
218
+ frac = options .get ("frac" , 0.6666666 )
219
+ # missing ='drop' is the default value for lowess but not for OLS (None)
220
+ # we force it here in case statsmodels change their defaults
221
+ y_out = sm .nonparametric .lowess (y , x , missing = "drop" , frac = frac )[:, 1 ]
222
+ hover_header = "<b>LOWESS trendline</b><br><br>"
223
+ return y_out , hover_header , None
224
+
225
+
226
+ def ma (options , x , y , x_label , y_label , non_missing ):
227
+ y_out = pd .Series (y , index = x ).rolling (** options ).mean ()[non_missing ]
228
+ hover_header = "<b>Moving Average trendline</b><br><br>"
229
+ return y_out , hover_header , None
230
+
231
+
232
+ def ewm (options , x , y , x_label , y_label , non_missing ):
233
+ y_out = pd .Series (y , index = x ).ewm (** options ).mean ()[non_missing ]
234
+ hover_header = "<b>EWM trendline</b><br><br>"
235
+ return y_out , hover_header , None
236
+
237
+
238
+ def ols (options , x , y , x_label , y_label , non_missing ):
239
+ import statsmodels .api as sm
240
+
241
+ add_constant = options .get ("add_constant" , True )
242
+ fit_results = sm .OLS (
243
+ y , sm .add_constant (x ) if add_constant else x , missing = "drop"
244
+ ).fit ()
245
+ y_out = fit_results .predict ()
246
+ hover_header = "<b>OLS trendline</b><br>"
247
+ if len (fit_results .params ) == 2 :
248
+ hover_header += "%s = %g * %s + %g<br>" % (
249
+ y_label ,
250
+ fit_results .params [1 ],
251
+ x_label ,
252
+ fit_results .params [0 ],
253
+ )
254
+ elif not add_constant :
255
+ hover_header += "%s = %g* %s<br>" % (y_label , fit_results .params [0 ], x_label ,)
256
+ else :
257
+ hover_header += "%s = %g<br>" % (y_label , fit_results .params [0 ],)
258
+ hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
259
+ return y_out , hover_header , fit_results
260
+
261
+
262
+ trendline_functions = dict (lowess = lowess , ma = ma , ewm = ewm , ols = ols )
263
+
264
+
215
265
def make_trace_kwargs (args , trace_spec , trace_data , mapping_labels , sizeref ):
216
266
"""Populates a dict with arguments to update trace
217
267
@@ -286,12 +336,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286
336
mapping_labels ["count" ] = "%{x}"
287
337
elif attr_name == "trendline" :
288
338
if (
289
- attr_value [ 0 ] in [ "ols" , "lowess" , "ma" , "ewm" ]
339
+ attr_value in trendline_functions
290
340
and args ["x" ]
291
341
and args ["y" ]
292
342
and len (trace_data [[args ["x" ], args ["y" ]]].dropna ()) > 1
293
343
):
294
- import statsmodels .api as sm
295
344
296
345
# sorting is bad but trace_specs with "trendline" have no other attrs
297
346
sorted_trace_data = trace_data .sort_values (by = args ["x" ])
@@ -322,56 +371,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
322
371
np .logical_or (np .isnan (y ), np .isnan (x ))
323
372
)
324
373
trace_patch ["x" ] = sorted_trace_data [args ["x" ]][non_missing ]
325
-
326
- if attr_value [0 ] == "lowess" :
327
- alpha = attr_value [1 ] or 0.6666666
328
- # missing ='drop' is the default value for lowess but not for OLS (None)
329
- # we force it here in case statsmodels change their defaults
330
- trendline = sm .nonparametric .lowess (
331
- y , x , missing = "drop" , frac = alpha
332
- )
333
- trace_patch ["y" ] = trendline [:, 1 ]
334
- hover_header = "<b>LOWESS trendline</b><br><br>"
335
- elif attr_value [0 ] == "ma" :
336
- trace_patch ["y" ] = (
337
- pd .Series (y [non_missing ])
338
- .rolling (window = attr_value [1 ] or 3 )
339
- .mean ()
340
- )
341
- elif attr_value [0 ] == "ewm" :
342
- trace_patch ["y" ] = (
343
- pd .Series (y [non_missing ])
344
- .ewm (alpha = attr_value [1 ] or 0.5 )
345
- .mean ()
346
- )
347
- elif attr_value [0 ] == "ols" :
348
- add_constant = attr_value [1 ] is not False
349
- fit_results = sm .OLS (
350
- y , sm .add_constant (x ) if add_constant else x , missing = "drop"
351
- ).fit ()
352
- trace_patch ["y" ] = fit_results .predict ()
353
- hover_header = "<b>OLS trendline</b><br>"
354
- if len (fit_results .params ) == 2 :
355
- hover_header += "%s = %g * %s + %g<br>" % (
356
- args ["y" ],
357
- fit_results .params [1 ],
358
- args ["x" ],
359
- fit_results .params [0 ],
360
- )
361
- elif not add_constant :
362
- hover_header += "%s = %g* %s<br>" % (
363
- args ["y" ],
364
- fit_results .params [0 ],
365
- args ["x" ],
366
- )
367
- else :
368
- hover_header += "%s = %g<br>" % (
369
- args ["y" ],
370
- fit_results .params [0 ],
371
- )
372
- hover_header += (
373
- "R<sup>2</sup>=%f<br><br>" % fit_results .rsquared
374
- )
374
+ trendline_function = trendline_functions [attr_value ]
375
+ y_out , hover_header , fit_results = trendline_function (
376
+ args ["trendline_options" ],
377
+ x ,
378
+ y ,
379
+ args ["x" ],
380
+ args ["y" ],
381
+ non_missing ,
382
+ )
383
+ assert len (y_out ) == len (
384
+ trace_patch ["x" ]
385
+ ), "missing-data-handling failure in trendline code"
386
+ trace_patch ["y" ] = y_out
375
387
mapping_labels [get_label (args , args ["x" ])] = "%{x}"
376
388
mapping_labels [get_label (args , args ["y" ])] = "%{y} <b>(trend)</b>"
377
389
elif attr_name .startswith ("error" ):
@@ -1795,9 +1807,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1795
1807
):
1796
1808
args ["facet_col_wrap" ] = 0
1797
1809
1798
- if args .get ("trendline" , None ) is not None :
1799
- if isinstance (args ["trendline" ], str ):
1800
- args ["trendline" ] = (args ["trendline" ], None )
1810
+ if "trendline_options" in args and args ["trendline_options" ] is None :
1811
+ args ["trendline_options" ] = dict ()
1801
1812
1802
1813
# Compute applicable grouping attributes
1803
1814
for k in group_attrables :
0 commit comments