Skip to content

Commit 65abd71

Browse files
make trendlines more robust
1 parent fd3b741 commit 65abd71

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,24 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
277277
attr_value in ["ols", "lowess"]
278278
and args["x"]
279279
and args["y"]
280-
and len(trace_data) > 1
280+
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
281281
):
282282
import statsmodels.api as sm
283283

284284
# sorting is bad but trace_specs with "trendline" have no other attrs
285285
sorted_trace_data = trace_data.sort_values(by=args["x"])
286-
y = sorted_trace_data[args["y"]]
287-
x = sorted_trace_data[args["x"]]
286+
y = sorted_trace_data[args["y"]].values
287+
x = sorted_trace_data[args["x"]].values
288288

289+
x_is_date = False
289290
if x.dtype.type == np.datetime64:
290291
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
292+
x_is_date = True
293+
elif x.dtype.type == np.object_:
294+
x = x.astype(np.float64)
295+
296+
if y.dtype.type == np.object_:
297+
y = y.astype(np.float64)
291298

292299
if attr_value == "lowess":
293300
# missing ='drop' is the default value for lowess but not for OLS (None)
@@ -298,25 +305,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
298305
hover_header = "<b>LOWESS trendline</b><br><br>"
299306
elif attr_value == "ols":
300307
fit_results = sm.OLS(
301-
y.values, sm.add_constant(x.values), missing="drop"
308+
y, sm.add_constant(x), missing="drop"
302309
).fit()
303310
trace_patch["y"] = fit_results.predict()
304311
trace_patch["x"] = x[
305312
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
306313
]
307314
hover_header = "<b>OLS trendline</b><br>"
308-
hover_header += "%s = %g * %s + %g<br>" % (
309-
args["y"],
310-
fit_results.params[1],
311-
args["x"],
312-
fit_results.params[0],
313-
)
315+
if len(fit_results.params) == 2:
316+
hover_header += "%s = %g * %s + %g<br>" % (
317+
args["y"],
318+
fit_results.params[1],
319+
args["x"],
320+
fit_results.params[0],
321+
)
322+
else:
323+
hover_header += "%s = %g<br>" % (
324+
args["y"],
325+
fit_results.params[0],
326+
)
314327
hover_header += (
315328
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
316329
)
330+
if x_is_date:
331+
trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
317332
mapping_labels[get_label(args, args["x"])] = "%{x}"
318333
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
319-
320334
elif attr_name.startswith("error"):
321335
error_xy = attr_name[:7]
322336
arr = "arrayminus" if attr_name.endswith("minus") else "array"
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
11
import plotly.express as px
22
import numpy as np
3+
import pandas as pd
4+
import pytest
5+
from datetime import datetime
36

47

5-
def test_trendline_nan_values():
8+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
9+
def test_trendline_results_passthrough(mode):
10+
df = px.data.gapminder().query("continent == 'Oceania'")
11+
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
12+
assert len(fig.data) == 4
13+
for trace in fig["data"][0::2]:
14+
assert "trendline" not in trace.hovertemplate
15+
for trendline in fig["data"][1::2]:
16+
assert "trendline" in trendline.hovertemplate
17+
if mode == "ols":
18+
assert "R<sup>2</sup>" in trendline.hovertemplate
19+
results = px.get_trendline_results(fig)
20+
if mode == "ols":
21+
assert len(results) == 2
22+
assert results["country"].values[0] == "Australia"
23+
assert results["country"].values[0] == "Australia"
24+
au_result = results["px_fit_results"].values[0]
25+
assert len(au_result.params) == 2
26+
else:
27+
assert len(results) == 0
28+
29+
30+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
31+
def test_trendline_enough_values(mode):
32+
fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode)
33+
assert len(fig.data) == 2
34+
assert len(fig.data[1].x) == 2
35+
fig = px.scatter(x=[0], y=[0], trendline=mode)
36+
assert len(fig.data) == 2
37+
assert fig.data[1].x is None
38+
fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode)
39+
assert len(fig.data) == 2
40+
assert fig.data[1].x is None
41+
fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode)
42+
assert len(fig.data) == 2
43+
assert fig.data[1].x is None
44+
fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode)
45+
assert len(fig.data) == 2
46+
assert len(fig.data[1].x) == 2
47+
48+
49+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
50+
def test_trendline_nan_values(mode):
651
df = px.data.gapminder().query("continent == 'Oceania'")
752
start_date = 1970
853
df["pop"][df["year"] < start_date] = np.nan
9-
modes = ["ols", "lowess"]
10-
for mode in modes:
11-
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
12-
for trendline in fig["data"][1::2]:
13-
assert trendline.x[0] >= start_date
14-
assert len(trendline.x) == len(trendline.y)
54+
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
55+
for trendline in fig["data"][1::2]:
56+
assert trendline.x[0] >= start_date
57+
assert len(trendline.x) == len(trendline.y)
58+
59+
60+
def test_no_slope_ols_trendline():
61+
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
62+
assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number)
63+
results = px.get_trendline_results(fig)
64+
params = results["px_fit_results"].iloc[0].params
65+
assert np.all(np.isclose(params, [0, 1]))
66+
67+
fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
68+
assert "y = 0" in fig.data[1].hovertemplate
69+
results = px.get_trendline_results(fig)
70+
params = results["px_fit_results"].iloc[0].params
71+
assert np.all(np.isclose(params, [0]))
72+
73+
fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
74+
assert "y = 0" in fig.data[1].hovertemplate
75+
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
76+
assert "y = 0 * x + 1" in fig.data[1].hovertemplate
77+
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
78+
assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate
79+
80+
81+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
82+
def test_trendline_on_timeseries(mode):
83+
df = px.data.stocks()
84+
df["date"] = pd.to_datetime(df["date"])
85+
fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
86+
assert len(fig.data) == 2
87+
assert len(fig.data[0].x) == len(fig.data[1].x)
88+
assert type(fig.data[0].x[0]) == datetime
89+
assert type(fig.data[1].x[0]) == datetime
90+
assert np.all(fig.data[0].x == fig.data[1].x)

0 commit comments

Comments
 (0)