Skip to content

Commit b5d5611

Browse files
move trendline code to own module
1 parent f036e5c commit b5d5611

File tree

4 files changed

+62
-60
lines changed

4 files changed

+62
-60
lines changed

Diff for: doc/apidoc/plotly.express.rst

+1
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@ plotly's high-level API for rapid figure generation. ::
5959

6060
generated/plotly.express.data.rst
6161
generated/plotly.express.colors.rst
62+
generated/plotly.express.trendline_functions.rst

Diff for: packages/python/plotly/plotly/express/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
from ._special_inputs import IdentityMap, Constant, Range # noqa: F401
6161

62-
from . import data, colors # noqa: F401
62+
from . import data, colors, trendline_functions # noqa: F401
6363

6464
__all__ = [
6565
"scatter",

Diff for: packages/python/plotly/plotly/express/_core.py

+2-59
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
44
from ._special_inputs import IdentityMap, Constant, Range
5+
from .trendline_functions import ols, lowess, ma, ewm
56

67
from _plotly_utils.basevalidators import ColorscaleValidator
78
from plotly.colors import qualitative, sequential
@@ -212,65 +213,6 @@ def make_mapping(args, variable):
212213
)
213214

214215

215-
def lowess(options, x, y, x_label, y_label, non_missing):
216-
import statsmodels.api as sm
217-
218-
frac = options.get("frac", 0.6666666)
219-
# missing ='drop' is the default value for lowess but not for OLS (None)
220-
# we force it here in case statsmodels change their defaults
221-
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
222-
hover_header = "<b>LOWESS trendline</b><br><br>"
223-
return y_out, hover_header, None
224-
225-
226-
def ma(options, x, y, x_label, y_label, non_missing):
227-
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
228-
hover_header = "<b>Moving Average trendline</b><br><br>"
229-
return y_out, hover_header, None
230-
231-
232-
def ewm(options, x, y, x_label, y_label, non_missing):
233-
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
234-
hover_header = "<b>EWM trendline</b><br><br>"
235-
return y_out, hover_header, None
236-
237-
238-
def ols(options, x, y, x_label, y_label, non_missing):
239-
import statsmodels.api as sm
240-
241-
add_constant = options.get("add_constant", True)
242-
log_x = options.get("log_x", False)
243-
log_y = options.get("log_y", False)
244-
245-
if log_y:
246-
y = np.log(y)
247-
if log_x:
248-
x = np.log(x)
249-
if add_constant:
250-
x = sm.add_constant(x)
251-
fit_results = sm.OLS(y, x, missing="drop").fit()
252-
y_out = fit_results.predict()
253-
if log_y:
254-
y_out = np.exp(y_out)
255-
hover_header = "<b>OLS trendline</b><br>"
256-
if len(fit_results.params) == 2:
257-
hover_header += "%s = %g * %s + %g<br>" % (
258-
y_label,
259-
fit_results.params[1],
260-
x_label,
261-
fit_results.params[0],
262-
)
263-
elif not add_constant:
264-
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
265-
else:
266-
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
267-
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
268-
return y_out, hover_header, fit_results
269-
270-
271-
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
272-
273-
274216
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
275217
"""Populates a dict with arguments to update trace
276218
@@ -344,6 +286,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
344286
if trace_spec.constructor == go.Histogram:
345287
mapping_labels["count"] = "%{x}"
346288
elif attr_name == "trendline":
289+
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
347290
if (
348291
attr_value in trendline_functions
349292
and args["x"]
Diff for: packages/python/plotly/plotly/express/trendline_functions.py

+58
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,58 @@
1+
import pandas as pd
2+
import numpy as np
3+
4+
5+
def ols(options, x, y, x_label, y_label, non_missing):
    """Fit an ordinary-least-squares trendline with statsmodels.

    Parameters
    ----------
    options : dict
        Recognised keys: ``"add_constant"`` (default ``True``) adds an
        intercept column to ``x``; ``"log_x"`` / ``"log_y"`` (default
        ``False``) fit against the natural log of the respective input.
    x, y
        Data to regress; passed to ``np.log`` / ``sm.OLS`` as-is.
        NOTE(review): presumably 1-D numeric arrays -- confirm against caller.
    x_label, y_label : str
        Names used in the equation shown in the hover text. A label is
        wrapped as ``log(<label>)`` when its axis is log-transformed.
    non_missing
        Unused here; accepted so every trendline function shares one
        signature.

    Returns
    -------
    tuple
        ``(y_out, hover_header, fit_results)``: the fitted values (mapped
        back through ``np.exp`` when ``log_y`` is set), the HTML hover
        header with the fitted equation and R-squared, and the statsmodels
        fit results object.
    """
    # Imported lazily so statsmodels is only needed when this trendline is used.
    import statsmodels.api as sm

    with_intercept = options.get("add_constant", True)
    fit_log_x = options.get("log_x", False)
    fit_log_y = options.get("log_y", False)

    if fit_log_y:
        y, y_label = np.log(y), "log(%s)" % y_label
    if fit_log_x:
        x, x_label = np.log(x), "log(%s)" % x_label

    design = sm.add_constant(x) if with_intercept else x
    # missing="drop" is passed explicitly so rows with missing values are excluded.
    fit_results = sm.OLS(y, design, missing="drop").fit()

    predicted = fit_results.predict()
    # Undo the log transform so the trendline is in the original y units.
    y_out = np.exp(predicted) if fit_log_y else predicted

    coefficients = fit_results.params
    if len(coefficients) == 2:
        # Two parameters: index 0 is printed as the offset, index 1 as the slope.
        equation = "%s = %g * %s + %g<br>" % (
            y_label,
            coefficients[1],
            x_label,
            coefficients[0],
        )
    elif not with_intercept:
        equation = "%s = %g* %s<br>" % (y_label, coefficients[0], x_label)
    else:
        # Intercept was requested but the fit did not yield exactly two
        # parameters; only the first parameter is reported.
        equation = "%s = %g<br>" % (y_label, coefficients[0])

    hover_header = "".join(
        [
            "<b>OLS trendline</b><br>",
            equation,
            "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared,
        ]
    )
    return y_out, hover_header, fit_results
38+
39+
40+
def lowess(options, x, y, x_label, y_label, non_missing):
    """Compute a LOWESS-smoothed trendline with statsmodels.

    Parameters
    ----------
    options : dict
        Only ``"frac"`` is read (default ``0.6666666``); it is forwarded to
        ``sm.nonparametric.lowess``.
    x, y
        Data to smooth, forwarded to statsmodels unchanged.
    x_label, y_label, non_missing
        Unused here; accepted so every trendline function shares one
        signature.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)``: column 1 of the statsmodels
        result, the HTML hover header, and ``None`` in the fit-results slot.
    """
    # Imported lazily so statsmodels is only needed when this trendline is used.
    import statsmodels.api as sm

    # missing="drop" is spelled out rather than relying on the library default.
    smoothed = sm.nonparametric.lowess(
        y, x, missing="drop", frac=options.get("frac", 0.6666666)
    )
    return smoothed[:, 1], "<b>LOWESS trendline</b><br><br>", None
47+
48+
49+
def ma(options, x, y, x_label, y_label, non_missing):
    """Compute a moving-average trendline using a pandas rolling window.

    Parameters
    ----------
    options : dict
        Forwarded verbatim as keyword arguments to ``pandas.Series.rolling``
        (e.g. ``window``, ``min_periods``).
    x, y
        ``y`` becomes the values and ``x`` the index of the series that is
        averaged.
    x_label, y_label
        Unused here; accepted so every trendline function shares one
        signature.
    non_missing
        Indexer applied to the averaged series.
        NOTE(review): presumably a boolean mask of usable rows -- confirm
        against caller.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)``: the selected rolling means, the HTML
        hover header, and ``None`` in the fit-results slot.
    """
    series = pd.Series(y, index=x)
    rolling_mean = series.rolling(**options).mean()
    return rolling_mean[non_missing], "<b>Moving Average trendline</b><br><br>", None
53+
54+
55+
def ewm(options, x, y, x_label, y_label, non_missing):
    """Compute an exponentially-weighted-mean trendline using pandas.

    Parameters
    ----------
    options : dict
        Forwarded verbatim as keyword arguments to ``pandas.Series.ewm``
        (e.g. ``alpha``, ``halflife``, ``adjust``).
    x, y
        ``y`` becomes the values and ``x`` the index of the series that is
        averaged.
    x_label, y_label
        Unused here; accepted so every trendline function shares one
        signature.
    non_missing
        Indexer applied to the averaged series.
        NOTE(review): presumably a boolean mask of usable rows -- confirm
        against caller.

    Returns
    -------
    tuple
        ``(y_out, hover_header, None)``: the selected weighted means, the
        HTML hover header, and ``None`` in the fit-results slot.
    """
    series = pd.Series(y, index=x)
    weighted_mean = series.ewm(**options).mean()
    return weighted_mean[non_missing], "<b>EWM trendline</b><br><br>", None

0 commit comments

Comments
 (0)