Skip to content

Commit 8be4915

Browse files
Merge pull request #2554 from plotly/trendline_fix
make trendlines more robust
2 parents (fd3b741 + e823437); commit 8be4915

File tree

3 files changed: 140 additions (+140), 18 deletions (-18)

Diff for: CHANGELOG.md

+1
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
88

99
- Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524))
1010
- Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)).
11+
- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554))
1112

1213
## [4.8.1] - 2020-05-28
1314

Diff for: packages/python/plotly/plotly/express/_core.py

+36 -11
Original file line number | Diff line number | Diff line change
@@ -277,17 +277,35 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
277277
attr_value in ["ols", "lowess"]
278278
and args["x"]
279279
and args["y"]
280-
and len(trace_data) > 1
280+
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
281281
):
282282
import statsmodels.api as sm
283283

284284
# sorting is bad but trace_specs with "trendline" have no other attrs
285285
sorted_trace_data = trace_data.sort_values(by=args["x"])
286-
y = sorted_trace_data[args["y"]]
287-
x = sorted_trace_data[args["x"]]
286+
y = sorted_trace_data[args["y"]].values
287+
x = sorted_trace_data[args["x"]].values
288288

289+
x_is_date = False
289290
if x.dtype.type == np.datetime64:
290291
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
292+
x_is_date = True
293+
elif x.dtype.type == np.object_:
294+
try:
295+
x = x.astype(np.float64)
296+
except ValueError:
297+
raise ValueError(
298+
"Could not convert value of 'x' ('%s') into a numeric type. "
299+
"If 'x' contains stringified dates, please convert to a datetime column."
300+
% args["x"]
301+
)
302+
if y.dtype.type == np.object_:
303+
try:
304+
y = y.astype(np.float64)
305+
except ValueError:
306+
raise ValueError(
307+
"Could not convert value of 'y' into a numeric type."
308+
)
291309

292310
if attr_value == "lowess":
293311
# missing ='drop' is the default value for lowess but not for OLS (None)
@@ -298,25 +316,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
298316
hover_header = "<b>LOWESS trendline</b><br><br>"
299317
elif attr_value == "ols":
300318
fit_results = sm.OLS(
301-
y.values, sm.add_constant(x.values), missing="drop"
319+
y, sm.add_constant(x), missing="drop"
302320
).fit()
303321
trace_patch["y"] = fit_results.predict()
304322
trace_patch["x"] = x[
305323
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
306324
]
307325
hover_header = "<b>OLS trendline</b><br>"
308-
hover_header += "%s = %g * %s + %g<br>" % (
309-
args["y"],
310-
fit_results.params[1],
311-
args["x"],
312-
fit_results.params[0],
313-
)
326+
if len(fit_results.params) == 2:
327+
hover_header += "%s = %g * %s + %g<br>" % (
328+
args["y"],
329+
fit_results.params[1],
330+
args["x"],
331+
fit_results.params[0],
332+
)
333+
else:
334+
hover_header += "%s = %g<br>" % (
335+
args["y"],
336+
fit_results.params[0],
337+
)
314338
hover_header += (
315339
"R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
316340
)
341+
if x_is_date:
342+
trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
317343
mapping_labels[get_label(args, args["x"])] = "%{x}"
318344
mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
319-
320345
elif attr_name.startswith("error"):
321346
error_xy = attr_name[:7]
322347
arr = "arrayminus" if attr_name.endswith("minus") else "array"

Diff for: trendline test module (file path missing from this capture)

+103 -7
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,110 @@
11
import plotly.express as px
22
import numpy as np
3+
import pandas as pd
4+
import pytest
5+
from datetime import datetime
36

47

5-
def test_trendline_nan_values():
8+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
9+
def test_trendline_results_passthrough(mode):
10+
df = px.data.gapminder().query("continent == 'Oceania'")
11+
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
12+
assert len(fig.data) == 4
13+
for trace in fig["data"][0::2]:
14+
assert "trendline" not in trace.hovertemplate
15+
for trendline in fig["data"][1::2]:
16+
assert "trendline" in trendline.hovertemplate
17+
if mode == "ols":
18+
assert "R<sup>2</sup>" in trendline.hovertemplate
19+
results = px.get_trendline_results(fig)
20+
if mode == "ols":
21+
assert len(results) == 2
22+
assert results["country"].values[0] == "Australia"
23+
assert results["country"].values[0] == "Australia"
24+
au_result = results["px_fit_results"].values[0]
25+
assert len(au_result.params) == 2
26+
else:
27+
assert len(results) == 0
28+
29+
30+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
31+
def test_trendline_enough_values(mode):
32+
fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode)
33+
assert len(fig.data) == 2
34+
assert len(fig.data[1].x) == 2
35+
fig = px.scatter(x=[0], y=[0], trendline=mode)
36+
assert len(fig.data) == 2
37+
assert fig.data[1].x is None
38+
fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode)
39+
assert len(fig.data) == 2
40+
assert fig.data[1].x is None
41+
fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode)
42+
assert len(fig.data) == 2
43+
assert fig.data[1].x is None
44+
fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode)
45+
assert len(fig.data) == 2
46+
assert fig.data[1].x is None
47+
fig = px.scatter(
48+
x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode
49+
)
50+
assert len(fig.data) == 2
51+
assert fig.data[1].x is None
52+
fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode)
53+
assert len(fig.data) == 2
54+
assert len(fig.data[1].x) == 2
55+
fig = px.scatter(
56+
x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode
57+
)
58+
assert len(fig.data) == 2
59+
assert len(fig.data[1].x) == 2
60+
61+
62+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
63+
def test_trendline_nan_values(mode):
664
df = px.data.gapminder().query("continent == 'Oceania'")
765
start_date = 1970
866
df["pop"][df["year"] < start_date] = np.nan
9-
modes = ["ols", "lowess"]
10-
for mode in modes:
11-
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
12-
for trendline in fig["data"][1::2]:
13-
assert trendline.x[0] >= start_date
14-
assert len(trendline.x) == len(trendline.y)
67+
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
68+
for trendline in fig["data"][1::2]:
69+
assert trendline.x[0] >= start_date
70+
assert len(trendline.x) == len(trendline.y)
71+
72+
73+
def test_no_slope_ols_trendline():
74+
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
75+
assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number)
76+
results = px.get_trendline_results(fig)
77+
params = results["px_fit_results"].iloc[0].params
78+
assert np.all(np.isclose(params, [0, 1]))
79+
80+
fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
81+
assert "y = 0" in fig.data[1].hovertemplate
82+
results = px.get_trendline_results(fig)
83+
params = results["px_fit_results"].iloc[0].params
84+
assert np.all(np.isclose(params, [0]))
85+
86+
fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
87+
assert "y = 0" in fig.data[1].hovertemplate
88+
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
89+
assert "y = 0 * x + 1" in fig.data[1].hovertemplate
90+
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
91+
assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate
92+
93+
94+
@pytest.mark.parametrize("mode", ["ols", "lowess"])
95+
def test_trendline_on_timeseries(mode):
96+
df = px.data.stocks()
97+
98+
with pytest.raises(ValueError) as err_msg:
99+
px.scatter(df, x="date", y="GOOG", trendline=mode)
100+
assert "Could not convert value of 'x' ('date') into a numeric type." in str(
101+
err_msg.value
102+
)
103+
104+
df["date"] = pd.to_datetime(df["date"])
105+
fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
106+
assert len(fig.data) == 2
107+
assert len(fig.data[0].x) == len(fig.data[1].x)
108+
assert type(fig.data[0].x[0]) == datetime
109+
assert type(fig.data[1].x[0]) == datetime
110+
assert np.all(fig.data[0].x == fig.data[1].x)

0 commit comments

Comments
 (0)