Skip to content

Commit 1dc6ab3

Browse files
extract trendline function API
1 parent ded9971 commit 1dc6ab3

File tree

3 files changed

+72
-55
lines changed

3 files changed

+72
-55
lines changed

Diff for: packages/python/plotly/plotly/express/_chart_types.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def scatter(
4646
marginal_x=None,
4747
marginal_y=None,
4848
trendline=None,
49+
trendline_options=None,
4950
trendline_color_override=None,
5051
log_x=False,
5152
log_y=False,
@@ -90,6 +91,7 @@ def density_contour(
9091
marginal_x=None,
9192
marginal_y=None,
9293
trendline=None,
94+
trendline_options=None,
9395
trendline_color_override=None,
9496
log_x=False,
9597
log_y=False,

Diff for: packages/python/plotly/plotly/express/_core.py

+66-55
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,56 @@ def make_mapping(args, variable):
212212
)
213213

214214

215+
def lowess(options, x, y, x_label, y_label, non_missing):
216+
import statsmodels.api as sm
217+
218+
frac = options.get("frac", 0.6666666)
219+
# missing ='drop' is the default value for lowess but not for OLS (None)
220+
# we force it here in case statsmodels change their defaults
221+
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
222+
hover_header = "<b>LOWESS trendline</b><br><br>"
223+
return y_out, hover_header, None
224+
225+
226+
def ma(options, x, y, x_label, y_label, non_missing):
227+
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
228+
hover_header = "<b>Moving Average trendline</b><br><br>"
229+
return y_out, hover_header, None
230+
231+
232+
def ewm(options, x, y, x_label, y_label, non_missing):
233+
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
234+
hover_header = "<b>EWM trendline</b><br><br>"
235+
return y_out, hover_header, None
236+
237+
238+
def ols(options, x, y, x_label, y_label, non_missing):
239+
import statsmodels.api as sm
240+
241+
add_constant = options.get("add_constant", True)
242+
fit_results = sm.OLS(
243+
y, sm.add_constant(x) if add_constant else x, missing="drop"
244+
).fit()
245+
y_out = fit_results.predict()
246+
hover_header = "<b>OLS trendline</b><br>"
247+
if len(fit_results.params) == 2:
248+
hover_header += "%s = %g * %s + %g<br>" % (
249+
y_label,
250+
fit_results.params[1],
251+
x_label,
252+
fit_results.params[0],
253+
)
254+
elif not add_constant:
255+
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
256+
else:
257+
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
258+
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
259+
return y_out, hover_header, fit_results
260+
261+
262+
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
263+
264+
215265
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
216266
"""Populates a dict with arguments to update trace
217267
@@ -286,12 +336,11 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286336
mapping_labels["count"] = "%{x}"
287337
elif attr_name == "trendline":
288338
if (
289-
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
339+
attr_value in trendline_functions
290340
and args["x"]
291341
and args["y"]
292342
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
293343
):
294-
import statsmodels.api as sm
295344

296345
# sorting is bad but trace_specs with "trendline" have no other attrs
297346
sorted_trace_data = trace_data.sort_values(by=args["x"])
@@ -322,56 +371,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
322371
np.logical_or(np.isnan(y), np.isnan(x))
323372
)
324373
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
325-
326-
if attr_value[0] == "lowess":
327-
alpha = attr_value[1] or 0.6666666
328-
# missing ='drop' is the default value for lowess but not for OLS (None)
329-
# we force it here in case statsmodels change their defaults
330-
trendline = sm.nonparametric.lowess(
331-
y, x, missing="drop", frac=alpha
332-
)
333-
trace_patch["y"] = trendline[:, 1]
334-
hover_header = "<b>LOWESS trendline</b><br><br>"
335-
elif attr_value[0] == "ma":
336-
trace_patch["y"] = (
337-
pd.Series(y[non_missing])
338-
.rolling(window=attr_value[1] or 3)
339-
.mean()
340-
)
341-
elif attr_value[0] == "ewm":
342-
trace_patch["y"] = (
343-
pd.Series(y[non_missing])
344-
.ewm(alpha=attr_value[1] or 0.5)
345-
.mean()
346-
)
347-
elif attr_value[0] == "ols":
348-
add_constant = attr_value[1] is not False
349-
fit_results = sm.OLS(
350-
y, sm.add_constant(x) if add_constant else x, missing="drop"
351-
).fit()
352-
trace_patch["y"] = fit_results.predict()
353-
hover_header = "<b>OLS trendline</b><br>"
354-
if len(fit_results.params) == 2:
355-
hover_header += "%s = %g * %s + %g<br>" % (
356-
args["y"],
357-
fit_results.params[1],
358-
args["x"],
359-
fit_results.params[0],
360-
)
361-
elif not add_constant:
362-
hover_header += "%s = %g* %s<br>" % (
363-
args["y"],
364-
fit_results.params[0],
365-
args["x"],
366-
)
367-
else:
368-
hover_header += "%s = %g<br>" % (
369-
args["y"],
370-
fit_results.params[0],
371-
)
372-
hover_header += (
373-
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
374-
)
374+
trendline_function = trendline_functions[attr_value]
375+
y_out, hover_header, fit_results = trendline_function(
376+
args["trendline_options"],
377+
x,
378+
y,
379+
args["x"],
380+
args["y"],
381+
non_missing,
382+
)
383+
assert len(y_out) == len(
384+
trace_patch["x"]
385+
), "missing-data-handling failure in trendline code"
386+
trace_patch["y"] = y_out
375387
mapping_labels[get_label(args, args["x"])] = "%{x}"
376388
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
377389
elif attr_name.startswith("error"):
@@ -1795,9 +1807,8 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17951807
):
17961808
args["facet_col_wrap"] = 0
17971809

1798-
if args.get("trendline", None) is not None:
1799-
if isinstance(args["trendline"], str):
1800-
args["trendline"] = (args["trendline"], None)
1810+
if "trendline_options" in args and args["trendline_options"] is None:
1811+
args["trendline_options"] = dict()
18011812

18021813
# Compute applicable grouping attributes
18031814
for k in group_attrables:

Diff for: packages/python/plotly/plotly/express/_doc.py

+4
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,10 @@
388388
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
389389
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
390390
],
391+
trendline_options=[
392+
"dict",
393+
"Options passed to the function named in the `trendline` argument.",
394+
],
391395
trendline_color_override=[
392396
"str",
393397
"Valid CSS color.",

0 commit comments

Comments
 (0)