Skip to content

Commit ed814d3

Browse files
docstrings
1 parent 310d16e commit ed814d3

File tree

2 files changed

+53
-18
lines changed

2 files changed

+53
-18
lines changed

Diff for: packages/python/plotly/plotly/express/_doc.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -384,18 +384,23 @@
384384
],
385385
trendline=[
386386
"str",
387-
"One of `'ols'` or `'lowess'`.",
387+
"One of `'ols'`, `'lowess'`, `'ma'` or `'ewma'`.",
388388
"If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.",
389389
"If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.",
390+
"If `'ma`', a Moving Average line will be drawn for each discrete-color/symbol group.",
391+
"If `'ewma`', an Exponentially Weighted Moving Average line will be drawn for each discrete-color/symbol group.",
392+
"See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how",
393+
"to configure them with the `trendline_options` argument.",
390394
],
391395
trendline_options=[
392396
"dict",
393-
"Options passed to the function named in the `trendline` argument.",
397+
"Options passed as the first argument to the function from `plotly.express.trendline_functions` ",
398+
"named in the `trendline` argument.",
394399
],
395400
trendline_color_override=[
396401
"str",
397402
"Valid CSS color.",
398-
"If provided, and if `trendline` is set, all trendlines will be drawn in this color.",
403+
"If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.",
399404
],
400405
render_mode=[
401406
"str",

Diff for: packages/python/plotly/plotly/express/trendline_functions/__init__.py

+45-15
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,47 @@
22
import numpy as np
33

44

5-
def ols(options, x_raw, x, y, x_label, y_label, non_missing):
5+
def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
6+
"""Ordinary Least Squares trendline function
7+
8+
Requires `statsmodels` to be installed.
9+
10+
Valid keys for the `trendline_options` dict are:
11+
12+
`add_constant` (`bool`, default `True`): if `False`, the trendline passes through
13+
the origin but if `True` a y-intercept is fitted.
14+
15+
`log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
16+
respect to the base 10 logarithm of the input. Note that this means no zeros can
17+
be present in the input.
18+
"""
19+
620
import statsmodels.api as sm
721

8-
add_constant = options.get("add_constant", True)
9-
log_x = options.get("log_x", False)
10-
log_y = options.get("log_y", False)
22+
add_constant = trendline_options.get("add_constant", True)
23+
log_x = trendline_options.get("log_x", False)
24+
log_y = trendline_options.get("log_y", False)
1125

1226
if log_y:
1327
if np.any(y == 0):
1428
raise ValueError(
1529
"Can't do OLS trendline with `log_y=True` when `y` contains zeros."
1630
)
17-
y = np.log(y)
18-
y_label = "log(%s)" % y_label
31+
y = np.log10(y)
32+
y_label = "log10(%s)" % y_label
1933
if log_x:
2034
if np.any(x == 0):
2135
raise ValueError(
2236
"Can't do OLS trendline with `log_x=True` when `x` contains zeros."
2337
)
24-
x = np.log(x)
25-
x_label = "log(%s)" % x_label
38+
x = np.log10(x)
39+
x_label = "log10(%s)" % x_label
2640
if add_constant:
2741
x = sm.add_constant(x)
2842
fit_results = sm.OLS(y, x, missing="drop").fit()
2943
y_out = fit_results.predict()
3044
if log_y:
31-
y_out = np.exp(y_out)
45+
y_out = np.power(10, y_out)
3246
hover_header = "<b>OLS trendline</b><br>"
3347
if len(fit_results.params) == 2:
3448
hover_header += "%s = %g * %s + %g<br>" % (
@@ -45,22 +59,38 @@ def ols(options, x_raw, x, y, x_label, y_label, non_missing):
4559
return y_out, hover_header, fit_results
4660

4761

48-
def lowess(options, x_raw, x, y, x_label, y_label, non_missing):
62+
def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
63+
"""Locally Weighted Scatterplot Smoothing trendline function
64+
65+
Requires `statsmodels` to be installed.
66+
67+
Valid keys for the `trendline_options` dict are:
68+
69+
`frac` (`float`, default `0.6666666`): the `frac` parameter from `statsmodels.api.nonparametric.lowess`
70+
"""
4971
import statsmodels.api as sm
5072

51-
frac = options.get("frac", 0.6666666)
73+
frac = trendline_options.get("frac", 0.6666666)
5274
y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
5375
hover_header = "<b>LOWESS trendline</b><br><br>"
5476
return y_out, hover_header, None
5577

5678

57-
def ma(options, x_raw, x, y, x_label, y_label, non_missing):
58-
y_out = pd.Series(y, index=x_raw).rolling(**options).mean()[non_missing]
79+
def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
80+
"""Moving Average trendline function
81+
82+
The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.rolling` function.
83+
"""
84+
y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing]
5985
hover_header = "<b>MA trendline</b><br><br>"
6086
return y_out, hover_header, None
6187

6288

63-
def ewma(options, x_raw, x, y, x_label, y_label, non_missing):
64-
y_out = pd.Series(y, index=x_raw).ewm(**options).mean()[non_missing]
89+
def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
90+
"""Exponentially Weighted Moving Average trendline function
91+
92+
The `trendline_options` dict is passed as keyword arguments into the `pandas.Series.ewma` function.
93+
"""
94+
y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing]
6595
hover_header = "<b>EWMA trendline</b><br><br>"
6696
return y_out, hover_header, None

0 commit comments

Comments
 (0)