Skip to content

Commit ded9971

Browse files
ma and ewm trendlines
1 parent c4fef05 commit ded9971

File tree

1 file changed

+35
-8
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+35
-8
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286286
mapping_labels["count"] = "%{x}"
287287
elif attr_name == "trendline":
288288
if (
289-
attr_value in ["ols", "lowess"]
289+
attr_value[0] in ["ols", "lowess", "ma", "ewm"]
290290
and args["x"]
291291
and args["y"]
292292
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
@@ -318,19 +318,36 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
318318
)
319319

320320
# preserve original values of "x" in case they're dates
321-
trace_patch["x"] = sorted_trace_data[args["x"]][
322-
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
323-
]
321+
non_missing = np.logical_not(
322+
np.logical_or(np.isnan(y), np.isnan(x))
323+
)
324+
trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
324325

325-
if attr_value == "lowess":
326+
if attr_value[0] == "lowess":
327+
alpha = attr_value[1] or 0.6666666
326328
# missing ='drop' is the default value for lowess but not for OLS (None)
327329
# we force it here in case statsmodels change their defaults
328-
trendline = sm.nonparametric.lowess(y, x, missing="drop")
330+
trendline = sm.nonparametric.lowess(
331+
y, x, missing="drop", frac=alpha
332+
)
329333
trace_patch["y"] = trendline[:, 1]
330334
hover_header = "<b>LOWESS trendline</b><br><br>"
331-
elif attr_value == "ols":
335+
elif attr_value[0] == "ma":
336+
trace_patch["y"] = (
337+
pd.Series(y[non_missing])
338+
.rolling(window=attr_value[1] or 3)
339+
.mean()
340+
)
341+
elif attr_value[0] == "ewm":
342+
trace_patch["y"] = (
343+
pd.Series(y[non_missing])
344+
.ewm(alpha=attr_value[1] or 0.5)
345+
.mean()
346+
)
347+
elif attr_value[0] == "ols":
348+
add_constant = attr_value[1] is not False
332349
fit_results = sm.OLS(
333-
y, sm.add_constant(x), missing="drop"
350+
y, sm.add_constant(x) if add_constant else x, missing="drop"
334351
).fit()
335352
trace_patch["y"] = fit_results.predict()
336353
hover_header = "<b>OLS trendline</b><br>"
@@ -341,6 +358,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
341358
args["x"],
342359
fit_results.params[0],
343360
)
361+
elif not add_constant:
362+
hover_header += "%s = %g* %s<br>" % (
363+
args["y"],
364+
fit_results.params[0],
365+
args["x"],
366+
)
344367
else:
345368
hover_header += "%s = %g<br>" % (
346369
args["y"],
@@ -1772,6 +1795,10 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17721795
):
17731796
args["facet_col_wrap"] = 0
17741797

1798+
if args.get("trendline", None) is not None:
1799+
if isinstance(args["trendline"], str):
1800+
args["trendline"] = (args["trendline"], None)
1801+
17751802
# Compute applicable grouping attributes
17761803
for k in group_attrables:
17771804
if k in args:

0 commit comments

Comments
 (0)