Skip to content

Commit 8285d29

Browse files
tests for new trendlines
1 parent bf15938 commit 8285d29

File tree

3 files changed

+127
-38
lines changed

3 files changed

+127
-38
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
44
from ._special_inputs import IdentityMap, Constant, Range
5-
from .trendline_functions import ols, lowess, ma, ewm
5+
from .trendline_functions import ols, lowess, ma, ewma
66

77
from _plotly_utils.basevalidators import ColorscaleValidator
88
from plotly.colors import qualitative, sequential
@@ -286,7 +286,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
286286
if trace_spec.constructor == go.Histogram:
287287
mapping_labels["count"] = "%{x}"
288288
elif attr_name == "trendline":
289-
trendline_functions = dict(lowess=lowess, ma=ma, ewm=ewm, ols=ols)
289+
trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols)
290290
if (
291291
attr_value in trendline_functions
292292
and args["x"]
@@ -326,6 +326,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
326326
trendline_function = trendline_functions[attr_value]
327327
y_out, hover_header, fit_results = trendline_function(
328328
args["trendline_options"],
329+
sorted_trace_data[args["x"]],
329330
x,
330331
y,
331332
args["x"],

Diff for: packages/python/plotly/plotly/express/trendline_functions.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44

5-
def ols(options, x, y, x_label, y_label, non_missing):
5+
def ols(options, x_raw, x, y, x_label, y_label, non_missing):
66
import statsmodels.api as sm
77

88
add_constant = options.get("add_constant", True)
@@ -30,14 +30,14 @@ def ols(options, x, y, x_label, y_label, non_missing):
3030
fit_results.params[0],
3131
)
3232
elif not add_constant:
33-
hover_header += "%s = %g* %s<br>" % (y_label, fit_results.params[0], x_label,)
33+
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label,)
3434
else:
3535
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
3636
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
3737
return y_out, hover_header, fit_results
3838

3939

40-
def lowess(options, x, y, x_label, y_label, non_missing):
40+
def lowess(options, x_raw, x, y, x_label, y_label, non_missing):
4141
import statsmodels.api as sm
4242

4343
frac = options.get("frac", 0.6666666)
@@ -46,13 +46,13 @@ def lowess(options, x, y, x_label, y_label, non_missing):
4646
return y_out, hover_header, None
4747

4848

49-
def ma(options, x, y, x_label, y_label, non_missing):
50-
y_out = pd.Series(y, index=x).rolling(**options).mean()[non_missing]
51-
hover_header = "<b>Moving Average trendline</b><br><br>"
49+
def ma(options, x_raw, x, y, x_label, y_label, non_missing):
50+
y_out = pd.Series(y, index=x_raw).rolling(**options).mean()[non_missing]
51+
hover_header = "<b>MA trendline</b><br><br>"
5252
return y_out, hover_header, None
5353

5454

55-
def ewm(options, x, y, x_label, y_label, non_missing):
56-
y_out = pd.Series(y, index=x).ewm(**options).mean()[non_missing]
57-
hover_header = "<b>EWM trendline</b><br><br>"
55+
def ewma(options, x_raw, x, y, x_label, y_label, non_missing):
56+
y_out = pd.Series(y, index=x_raw).ewm(**options).mean()[non_missing]
57+
hover_header = "<b>EWMA trendline</b><br><br>"
5858
return y_out, hover_header, None

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py

+115-27
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,27 @@
55
from datetime import datetime
66

77

8-
@pytest.mark.parametrize("mode", ["ols", "lowess"])
9-
def test_trendline_results_passthrough(mode):
8+
@pytest.mark.parametrize(
9+
"mode,options",
10+
[
11+
("ols", None),
12+
("ols", dict(log_x=True, log_y=True)),
13+
("lowess", None),
14+
("lowess", dict(frac=0.3)),
15+
("ma", dict(window=2)),
16+
("ewma", dict(alpha=0.5)),
17+
],
18+
)
19+
def test_trendline_results_passthrough(mode, options):
1020
df = px.data.gapminder().query("continent == 'Oceania'")
11-
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
21+
fig = px.scatter(
22+
df,
23+
x="year",
24+
y="pop",
25+
color="country",
26+
trendline=mode,
27+
trendline_options=options,
28+
)
1229
assert len(fig.data) == 4
1330
for trace in fig["data"][0::2]:
1431
assert "trendline" not in trace.hovertemplate
@@ -20,90 +37,161 @@ def test_trendline_results_passthrough(mode):
2037
if mode == "ols":
2138
assert len(results) == 2
2239
assert results["country"].values[0] == "Australia"
23-
assert results["country"].values[0] == "Australia"
2440
au_result = results["px_fit_results"].values[0]
2541
assert len(au_result.params) == 2
2642
else:
2743
assert len(results) == 0
2844

2945

30-
@pytest.mark.parametrize("mode", ["ols", "lowess"])
31-
def test_trendline_enough_values(mode):
32-
fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode)
46+
@pytest.mark.parametrize(
47+
"mode,options",
48+
[
49+
("ols", None),
50+
("ols", dict(add_constant=False, log_x=True, log_y=True)),
51+
("lowess", None),
52+
("lowess", dict(frac=0.3)),
53+
("ma", dict(window=2)),
54+
("ewma", dict(alpha=0.5)),
55+
],
56+
)
57+
def test_trendline_enough_values(mode, options):
58+
fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode, trendline_options=options)
3359
assert len(fig.data) == 2
3460
assert len(fig.data[1].x) == 2
35-
fig = px.scatter(x=[0], y=[0], trendline=mode)
61+
fig = px.scatter(x=[0], y=[0], trendline=mode, trendline_options=options)
3662
assert len(fig.data) == 2
3763
assert fig.data[1].x is None
38-
fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode)
64+
fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode, trendline_options=options)
3965
assert len(fig.data) == 2
4066
assert fig.data[1].x is None
41-
fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode)
67+
fig = px.scatter(
68+
x=[0, 1], y=np.array([0, np.nan]), trendline=mode, trendline_options=options
69+
)
4270
assert len(fig.data) == 2
4371
assert fig.data[1].x is None
44-
fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode)
72+
fig = px.scatter(
73+
x=[0, 1, None], y=[0, None, 1], trendline=mode, trendline_options=options
74+
)
4575
assert len(fig.data) == 2
4676
assert fig.data[1].x is None
4777
fig = px.scatter(
48-
x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode
78+
x=np.array([0, 1, np.nan]),
79+
y=np.array([0, np.nan, 1]),
80+
trendline=mode,
81+
trendline_options=options,
4982
)
5083
assert len(fig.data) == 2
5184
assert fig.data[1].x is None
52-
fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode)
85+
fig = px.scatter(
86+
x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode, trendline_options=options
87+
)
5388
assert len(fig.data) == 2
5489
assert len(fig.data[1].x) == 2
5590
fig = px.scatter(
56-
x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode
91+
x=np.array([0, 1, np.nan, 2]),
92+
y=np.array([1, np.nan, 1, 2]),
93+
trendline=mode,
94+
trendline_options=options,
5795
)
5896
assert len(fig.data) == 2
5997
assert len(fig.data[1].x) == 2
6098

6199

62-
@pytest.mark.parametrize("mode", ["ols", "lowess"])
63-
def test_trendline_nan_values(mode):
100+
@pytest.mark.parametrize(
101+
"mode,options",
102+
[
103+
("ols", None),
104+
("ols", dict(add_constant=False, log_x=True, log_y=True)),
105+
("lowess", None),
106+
("lowess", dict(frac=0.3)),
107+
("ma", dict(window=2)),
108+
("ewma", dict(alpha=0.5)),
109+
],
110+
)
111+
def test_trendline_nan_values(mode, options):
64112
df = px.data.gapminder().query("continent == 'Oceania'")
65113
start_date = 1970
66114
df["pop"][df["year"] < start_date] = np.nan
67-
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
115+
fig = px.scatter(
116+
df,
117+
x="year",
118+
y="pop",
119+
color="country",
120+
trendline=mode,
121+
trendline_options=options,
122+
)
68123
for trendline in fig["data"][1::2]:
69124
assert trendline.x[0] >= start_date
70125
assert len(trendline.x) == len(trendline.y)
71126

72127

73-
def test_no_slope_ols_trendline():
128+
def test_ols_trendline_slopes():
74129
fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols")
75-
assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number)
130+
assert "y = 1 * x + 0<br>" in fig.data[1].hovertemplate
76131
results = px.get_trendline_results(fig)
77132
params = results["px_fit_results"].iloc[0].params
78133
assert np.all(np.isclose(params, [0, 1]))
79134

135+
fig = px.scatter(x=[0, 1], y=[1, 2], trendline="ols")
136+
assert "y = 1 * x + 1<br>" in fig.data[1].hovertemplate
137+
results = px.get_trendline_results(fig)
138+
params = results["px_fit_results"].iloc[0].params
139+
assert np.all(np.isclose(params, [1, 1]))
140+
141+
fig = px.scatter(
142+
x=[0, 1], y=[1, 2], trendline="ols", trendline_options=dict(add_constant=False)
143+
)
144+
assert "y = 2 * x<br>" in fig.data[1].hovertemplate
145+
results = px.get_trendline_results(fig)
146+
params = results["px_fit_results"].iloc[0].params
147+
assert np.all(np.isclose(params, [2]))
148+
149+
fig = px.scatter(
150+
x=[1, 1], y=[0, 0], trendline="ols", trendline_options=dict(add_constant=False)
151+
)
152+
assert "y = 0 * x<br>" in fig.data[1].hovertemplate
153+
results = px.get_trendline_results(fig)
154+
params = results["px_fit_results"].iloc[0].params
155+
assert np.all(np.isclose(params, [0]))
156+
80157
fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols")
81-
assert "y = 0" in fig.data[1].hovertemplate
158+
assert "y = 0<br>" in fig.data[1].hovertemplate
82159
results = px.get_trendline_results(fig)
83160
params = results["px_fit_results"].iloc[0].params
84161
assert np.all(np.isclose(params, [0]))
85162

86163
fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols")
87-
assert "y = 0" in fig.data[1].hovertemplate
164+
assert "y = 0 * x + 0<br>" in fig.data[1].hovertemplate
88165
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols")
89-
assert "y = 0 * x + 1" in fig.data[1].hovertemplate
166+
assert "y = 0 * x + 1<br>" in fig.data[1].hovertemplate
90167
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols")
91-
assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate
168+
assert "y = 0 * x + 1.5<br>" in fig.data[1].hovertemplate
92169

93170

94-
@pytest.mark.parametrize("mode", ["ols", "lowess"])
95-
def test_trendline_on_timeseries(mode):
171+
@pytest.mark.parametrize(
172+
"mode,options",
173+
[
174+
("ols", None),
175+
("ols", dict(add_constant=False, log_x=True, log_y=True)),
176+
("lowess", None),
177+
("lowess", dict(frac=0.3)),
178+
("ma", dict(window=2)),
179+
("ma", dict(window="10d")),
180+
("ewma", dict(alpha=0.5)),
181+
],
182+
)
183+
def test_trendline_on_timeseries(mode, options):
96184
df = px.data.stocks()
97185

98186
with pytest.raises(ValueError) as err_msg:
99-
px.scatter(df, x="date", y="GOOG", trendline=mode)
187+
px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
100188
assert "Could not convert value of 'x' ('date') into a numeric type." in str(
101189
err_msg.value
102190
)
103191

104192
df["date"] = pd.to_datetime(df["date"])
105193
df["date"] = df["date"].dt.tz_localize("CET") # force a timezone
106-
fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
194+
fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
107195
assert len(fig.data) == 2
108196
assert len(fig.data[0].x) == len(fig.data[1].x)
109197
assert type(fig.data[0].x[0]) == datetime

0 commit comments

Comments
 (0)