5
5
from datetime import datetime
6
6
7
7
8
- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
9
- def test_trendline_results_passthrough (mode ):
8
+ @pytest .mark .parametrize (
9
+ "mode,options" ,
10
+ [
11
+ ("ols" , None ),
12
+ ("ols" , dict (log_x = True , log_y = True )),
13
+ ("lowess" , None ),
14
+ ("lowess" , dict (frac = 0.3 )),
15
+ ("ma" , dict (window = 2 )),
16
+ ("ewma" , dict (alpha = 0.5 )),
17
+ ],
18
+ )
19
+ def test_trendline_results_passthrough (mode , options ):
10
20
df = px .data .gapminder ().query ("continent == 'Oceania'" )
11
- fig = px .scatter (df , x = "year" , y = "pop" , color = "country" , trendline = mode )
21
+ fig = px .scatter (
22
+ df ,
23
+ x = "year" ,
24
+ y = "pop" ,
25
+ color = "country" ,
26
+ trendline = mode ,
27
+ trendline_options = options ,
28
+ )
12
29
assert len (fig .data ) == 4
13
30
for trace in fig ["data" ][0 ::2 ]:
14
31
assert "trendline" not in trace .hovertemplate
@@ -20,90 +37,161 @@ def test_trendline_results_passthrough(mode):
20
37
if mode == "ols" :
21
38
assert len (results ) == 2
22
39
assert results ["country" ].values [0 ] == "Australia"
23
- assert results ["country" ].values [0 ] == "Australia"
24
40
au_result = results ["px_fit_results" ].values [0 ]
25
41
assert len (au_result .params ) == 2
26
42
else :
27
43
assert len (results ) == 0
28
44
29
45
30
- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
31
- def test_trendline_enough_values (mode ):
32
- fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = mode )
46
+ @pytest .mark .parametrize (
47
+ "mode,options" ,
48
+ [
49
+ ("ols" , None ),
50
+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
51
+ ("lowess" , None ),
52
+ ("lowess" , dict (frac = 0.3 )),
53
+ ("ma" , dict (window = 2 )),
54
+ ("ewma" , dict (alpha = 0.5 )),
55
+ ],
56
+ )
57
+ def test_trendline_enough_values (mode , options ):
58
+ fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = mode , trendline_options = options )
33
59
assert len (fig .data ) == 2
34
60
assert len (fig .data [1 ].x ) == 2
35
- fig = px .scatter (x = [0 ], y = [0 ], trendline = mode )
61
+ fig = px .scatter (x = [0 ], y = [0 ], trendline = mode , trendline_options = options )
36
62
assert len (fig .data ) == 2
37
63
assert fig .data [1 ].x is None
38
- fig = px .scatter (x = [0 , 1 ], y = [0 , None ], trendline = mode )
64
+ fig = px .scatter (x = [0 , 1 ], y = [0 , None ], trendline = mode , trendline_options = options )
39
65
assert len (fig .data ) == 2
40
66
assert fig .data [1 ].x is None
41
- fig = px .scatter (x = [0 , 1 ], y = np .array ([0 , np .nan ]), trendline = mode )
67
+ fig = px .scatter (
68
+ x = [0 , 1 ], y = np .array ([0 , np .nan ]), trendline = mode , trendline_options = options
69
+ )
42
70
assert len (fig .data ) == 2
43
71
assert fig .data [1 ].x is None
44
- fig = px .scatter (x = [0 , 1 , None ], y = [0 , None , 1 ], trendline = mode )
72
+ fig = px .scatter (
73
+ x = [0 , 1 , None ], y = [0 , None , 1 ], trendline = mode , trendline_options = options
74
+ )
45
75
assert len (fig .data ) == 2
46
76
assert fig .data [1 ].x is None
47
77
fig = px .scatter (
48
- x = np .array ([0 , 1 , np .nan ]), y = np .array ([0 , np .nan , 1 ]), trendline = mode
78
+ x = np .array ([0 , 1 , np .nan ]),
79
+ y = np .array ([0 , np .nan , 1 ]),
80
+ trendline = mode ,
81
+ trendline_options = options ,
49
82
)
50
83
assert len (fig .data ) == 2
51
84
assert fig .data [1 ].x is None
52
- fig = px .scatter (x = [0 , 1 , None , 2 ], y = [1 , None , 1 , 2 ], trendline = mode )
85
+ fig = px .scatter (
86
+ x = [0 , 1 , None , 2 ], y = [1 , None , 1 , 2 ], trendline = mode , trendline_options = options
87
+ )
53
88
assert len (fig .data ) == 2
54
89
assert len (fig .data [1 ].x ) == 2
55
90
fig = px .scatter (
56
- x = np .array ([0 , 1 , np .nan , 2 ]), y = np .array ([1 , np .nan , 1 , 2 ]), trendline = mode
91
+ x = np .array ([0 , 1 , np .nan , 2 ]),
92
+ y = np .array ([1 , np .nan , 1 , 2 ]),
93
+ trendline = mode ,
94
+ trendline_options = options ,
57
95
)
58
96
assert len (fig .data ) == 2
59
97
assert len (fig .data [1 ].x ) == 2
60
98
61
99
62
- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
63
- def test_trendline_nan_values (mode ):
100
+ @pytest .mark .parametrize (
101
+ "mode,options" ,
102
+ [
103
+ ("ols" , None ),
104
+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
105
+ ("lowess" , None ),
106
+ ("lowess" , dict (frac = 0.3 )),
107
+ ("ma" , dict (window = 2 )),
108
+ ("ewma" , dict (alpha = 0.5 )),
109
+ ],
110
+ )
111
+ def test_trendline_nan_values (mode , options ):
64
112
df = px .data .gapminder ().query ("continent == 'Oceania'" )
65
113
start_date = 1970
66
114
df ["pop" ][df ["year" ] < start_date ] = np .nan
67
- fig = px .scatter (df , x = "year" , y = "pop" , color = "country" , trendline = mode )
115
+ fig = px .scatter (
116
+ df ,
117
+ x = "year" ,
118
+ y = "pop" ,
119
+ color = "country" ,
120
+ trendline = mode ,
121
+ trendline_options = options ,
122
+ )
68
123
for trendline in fig ["data" ][1 ::2 ]:
69
124
assert trendline .x [0 ] >= start_date
70
125
assert len (trendline .x ) == len (trendline .y )
71
126
72
127
73
- def test_no_slope_ols_trendline ():
128
+ def test_ols_trendline_slopes ():
74
129
fig = px .scatter (x = [0 , 1 ], y = [0 , 1 ], trendline = "ols" )
75
- assert "y = 1" in fig .data [1 ].hovertemplate # then + x*(some small number)
130
+ assert "y = 1 * x + 0<br> " in fig .data [1 ].hovertemplate
76
131
results = px .get_trendline_results (fig )
77
132
params = results ["px_fit_results" ].iloc [0 ].params
78
133
assert np .all (np .isclose (params , [0 , 1 ]))
79
134
135
+ fig = px .scatter (x = [0 , 1 ], y = [1 , 2 ], trendline = "ols" )
136
+ assert "y = 1 * x + 1<br>" in fig .data [1 ].hovertemplate
137
+ results = px .get_trendline_results (fig )
138
+ params = results ["px_fit_results" ].iloc [0 ].params
139
+ assert np .all (np .isclose (params , [1 , 1 ]))
140
+
141
+ fig = px .scatter (
142
+ x = [0 , 1 ], y = [1 , 2 ], trendline = "ols" , trendline_options = dict (add_constant = False )
143
+ )
144
+ assert "y = 2 * x<br>" in fig .data [1 ].hovertemplate
145
+ results = px .get_trendline_results (fig )
146
+ params = results ["px_fit_results" ].iloc [0 ].params
147
+ assert np .all (np .isclose (params , [2 ]))
148
+
149
+ fig = px .scatter (
150
+ x = [1 , 1 ], y = [0 , 0 ], trendline = "ols" , trendline_options = dict (add_constant = False )
151
+ )
152
+ assert "y = 0 * x<br>" in fig .data [1 ].hovertemplate
153
+ results = px .get_trendline_results (fig )
154
+ params = results ["px_fit_results" ].iloc [0 ].params
155
+ assert np .all (np .isclose (params , [0 ]))
156
+
80
157
fig = px .scatter (x = [1 , 1 ], y = [0 , 0 ], trendline = "ols" )
81
- assert "y = 0" in fig .data [1 ].hovertemplate
158
+ assert "y = 0<br> " in fig .data [1 ].hovertemplate
82
159
results = px .get_trendline_results (fig )
83
160
params = results ["px_fit_results" ].iloc [0 ].params
84
161
assert np .all (np .isclose (params , [0 ]))
85
162
86
163
fig = px .scatter (x = [1 , 2 ], y = [0 , 0 ], trendline = "ols" )
87
- assert "y = 0" in fig .data [1 ].hovertemplate
164
+ assert "y = 0 * x + 0<br> " in fig .data [1 ].hovertemplate
88
165
fig = px .scatter (x = [0 , 0 ], y = [1 , 1 ], trendline = "ols" )
89
- assert "y = 0 * x + 1" in fig .data [1 ].hovertemplate
166
+ assert "y = 0 * x + 1<br> " in fig .data [1 ].hovertemplate
90
167
fig = px .scatter (x = [0 , 0 ], y = [1 , 2 ], trendline = "ols" )
91
- assert "y = 0 * x + 1.5" in fig .data [1 ].hovertemplate
168
+ assert "y = 0 * x + 1.5<br> " in fig .data [1 ].hovertemplate
92
169
93
170
94
- @pytest .mark .parametrize ("mode" , ["ols" , "lowess" ])
95
- def test_trendline_on_timeseries (mode ):
171
+ @pytest .mark .parametrize (
172
+ "mode,options" ,
173
+ [
174
+ ("ols" , None ),
175
+ ("ols" , dict (add_constant = False , log_x = True , log_y = True )),
176
+ ("lowess" , None ),
177
+ ("lowess" , dict (frac = 0.3 )),
178
+ ("ma" , dict (window = 2 )),
179
+ ("ma" , dict (window = "10d" )),
180
+ ("ewma" , dict (alpha = 0.5 )),
181
+ ],
182
+ )
183
+ def test_trendline_on_timeseries (mode , options ):
96
184
df = px .data .stocks ()
97
185
98
186
with pytest .raises (ValueError ) as err_msg :
99
- px .scatter (df , x = "date" , y = "GOOG" , trendline = mode )
187
+ px .scatter (df , x = "date" , y = "GOOG" , trendline = mode , trendline_options = options )
100
188
assert "Could not convert value of 'x' ('date') into a numeric type." in str (
101
189
err_msg .value
102
190
)
103
191
104
192
df ["date" ] = pd .to_datetime (df ["date" ])
105
193
df ["date" ] = df ["date" ].dt .tz_localize ("CET" ) # force a timezone
106
- fig = px .scatter (df , x = "date" , y = "GOOG" , trendline = mode )
194
+ fig = px .scatter (df , x = "date" , y = "GOOG" , trendline = mode , trendline_options = options )
107
195
assert len (fig .data ) == 2
108
196
assert len (fig .data [0 ].x ) == len (fig .data [1 ].x )
109
197
assert type (fig .data [0 ].x [0 ]) == datetime
0 commit comments