Skip to content

Commit 3deb9d1

Browse files
fix up tests
1 parent d156aee commit 3deb9d1

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818

1919
NO_COLOR = "px_no_color_constant"
20+
trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols)
2021

2122
# Declare all supported attributes, across all plot types
2223
direct_attrables = (
@@ -286,10 +287,8 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286287
if trace_spec.constructor == go.Histogram:
287288
mapping_labels["count"] = "%{x}"
288289
elif attr_name == "trendline":
289-
trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols)
290290
if (
291-
attr_value in trendline_functions
292-
and args["x"]
291+
args["x"]
293292
and args["y"]
294293
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
295294
):
@@ -1760,6 +1759,13 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17601759
):
17611760
args["facet_col_wrap"] = 0
17621761

1762+
if "trendline" in args and args["trendline"] is not None:
1763+
if args["trendline"] not in trendline_functions:
1764+
raise ValueError(
1765+
"Value '%s' for `trendline` must be one of %s"
1766+
% (args["trendline"], trendline_functions.keys())
1767+
)
1768+
17631769
if "trendline_options" in args and args["trendline_options"] is None:
17641770
args["trendline_options"] = dict()
17651771

Diff for: packages/python/plotly/plotly/express/trendline_functions/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@ def ols(options, x_raw, x, y, x_label, y_label, non_missing):
1010
log_y = options.get("log_y", False)
1111

1212
if log_y:
13+
if np.any(y == 0):
14+
raise ValueError(
15+
"Can't do OLS trendline with `log_y=True` when `y` contains zeros."
16+
)
1317
y = np.log(y)
1418
y_label = "log(%s)" % y_label
1519
if log_x:
20+
if np.any(x == 0):
21+
raise ValueError(
22+
"Can't do OLS trendline with `log_x=True` when `x` contains zeros."
23+
)
1624
x = np.log(x)
1725
x_label = "log(%s)" % x_label
1826
if add_constant:

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_trendline_enough_values(mode, options):
9999
"mode,options",
100100
[
101101
("ols", None),
102+
("ols", dict(add_constant=False, log_x=True, log_y=True)),
102103
("lowess", None),
103104
("lowess", dict(frac=0.3)),
104105
("ma", dict(window=2)),
@@ -124,7 +125,8 @@ def test_trendline_nan_values(mode, options):
124125

125126
def test_ols_trendline_slopes():
126127
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
127-
assert "y = 1 * x + 0<br>" in fig.data[1].hovertemplate
128+
# should be "y = 1 * x + 0" but sometimes is some tiny number instead
129+
assert "y = 1 * x + " in fig.data[1].hovertemplate
128130
results = px.get_trendline_results(fig)
129131
params = results["px_fit_results"].iloc[0].params
130132
assert np.all(np.isclose(params, [0, 1]))

0 commit comments

Comments
 (0)