Skip to content

Commit b6806ff

Browse files
trendline_scope
1 parent e7a2fbc commit b6806ff

File tree

4 files changed

+100
-18
lines changed

4 files changed

+100
-18
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def scatter(
4848
trendline=None,
4949
trendline_options=None,
5050
trendline_color_override=None,
51+
trendline_scope="trace",
5152
log_x=False,
5253
log_y=False,
5354
range_x=None,
@@ -93,6 +94,7 @@ def density_contour(
9394
trendline=None,
9495
trendline_options=None,
9596
trendline_color_override=None,
97+
trendline_scope="trace",
9698
log_x=False,
9799
log_y=False,
98100
range_x=None,
@@ -202,7 +204,9 @@ def density_heatmap(
202204
z=[
203205
"For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.",
204206
],
205-
histfunc=["The arguments to this function are the values of `z`.",],
207+
histfunc=[
208+
"The arguments to this function are the values of `z`.",
209+
],
206210
),
207211
)
208212

@@ -467,7 +471,9 @@ def histogram(
467471
args=locals(),
468472
constructor=go.Histogram,
469473
trace_patch=dict(
470-
histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative),
474+
histnorm=histnorm,
475+
histfunc=histfunc,
476+
cumulative=dict(enabled=cumulative),
471477
),
472478
layout_patch=dict(barmode=barmode, barnorm=barnorm),
473479
)
@@ -527,7 +533,11 @@ def violin(
527533
args=locals(),
528534
constructor=go.Violin,
529535
trace_patch=dict(
530-
points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ",
536+
points=points,
537+
box=dict(visible=box),
538+
scalegroup=True,
539+
x0=" ",
540+
y0=" ",
531541
),
532542
layout_patch=dict(violinmode=violinmode),
533543
)

packages/python/plotly/plotly/express/_core.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
347347
)
348348

349349
# preserve original values of "x" in case they're dates
350+
# otherwise numpy/pandas can mess with the timezones
351+
# NB this means trendline functions must output one-to-one with the input series
352+
# i.e. we can't do resampling, because then the X values might not line up!
350353
non_missing = np.logical_not(
351354
np.logical_or(np.isnan(y), np.isnan(x))
352355
)
@@ -867,23 +870,25 @@ def make_trace_spec(args, constructor, attrs, trace_patch):
867870
result.append(trace_spec)
868871

869872
# Add trendline trace specifications
870-
if "trendline" in args and args["trendline"]:
871-
trace_spec = TraceSpec(
872-
constructor=go.Scattergl
873-
if constructor == go.Scattergl # could be contour
874-
else go.Scatter,
875-
attrs=["trendline"],
876-
trace_patch=dict(mode="lines"),
877-
marginal=None,
878-
)
879-
if args["trendline_color_override"]:
880-
trace_spec.trace_patch["line"] = dict(
881-
color=args["trendline_color_override"]
882-
)
883-
result.append(trace_spec)
873+
if args.get("trendline") and args.get("trendline_scope", "trace") == "trace":
874+
result.append(make_trendline_spec(args, constructor))
884875
return result
885876

886877

878+
def make_trendline_spec(args, constructor):
879+
trace_spec = TraceSpec(
880+
constructor=go.Scattergl
881+
if constructor == go.Scattergl # could be contour
882+
else go.Scatter,
883+
attrs=["trendline"],
884+
trace_patch=dict(mode="lines"),
885+
marginal=None,
886+
)
887+
if args["trendline_color_override"]:
888+
trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"])
889+
return trace_spec
890+
891+
887892
def one_group(x):
888893
return ""
889894

@@ -2127,6 +2132,27 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
21272132
fig.update_layout(template=args["template"], overwrite=True)
21282133
fig.frames = frame_list if len(frames) > 1 else []
21292134

2135+
if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
2136+
trendline_spec = make_trendline_spec(args, constructor)
2137+
trendline_trace = trendline_spec.constructor(
2138+
name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
2139+
)
2140+
if "line" not in trendline_spec.trace_patch: # no color override
2141+
for m in grouped_mappings:
2142+
if m.variable == "color":
2143+
next_color = m.sequence[len(m.val_map) % len(m.sequence)]
2144+
trendline_spec.trace_patch["line"] = dict(color=next_color)
2145+
patch, fit_results = make_trace_kwargs(
2146+
args, trendline_spec, args["data_frame"], {}, sizeref
2147+
)
2148+
trendline_trace.update(patch)
2149+
fig.add_trace(
2150+
trendline_trace, row="all", col="all", exclude_empty_subplots=True
2151+
)
2152+
fig.update_traces(selector=-1, showlegend=True)
2153+
if fit_results is not None:
2154+
trendline_rows.append(dict(px_fit_results=fit_results))
2155+
21302156
fig._px_trendlines = pd.DataFrame(trendline_rows)
21312157

21322158
configure_axes(args, constructor, fig, orders)

packages/python/plotly/plotly/express/_doc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,10 @@
325325
"Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.",
326326
],
327327
size_max=["int (default `20`)", "Set the maximum mark size when using `size`."],
328-
markers=["boolean (default `False`)", "If `True`, markers are shown on lines.",],
328+
markers=[
329+
"boolean (default `False`)",
330+
"If `True`, markers are shown on lines.",
331+
],
329332
log_x=[
330333
"boolean (default `False`)",
331334
"If `True`, the x-axis is log-scaled in cartesian coordinates.",
@@ -420,6 +423,10 @@
420423
"Valid CSS color.",
421424
"If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.",
422425
],
426+
trendline_scope=[
427+
"str (one of `'trace'` or `'overall'`, default `'trace'`)",
428+
"If `'trace'`, then one trendline is drawn per trace (i.e. per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.",
429+
],
423430
render_mode=[
424431
"str",
425432
"One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`",

packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,42 @@ def test_trendline_on_timeseries(mode, options):
200200
assert type(fig.data[1].x[0]) == datetime
201201
assert np.all(fig.data[0].x == fig.data[1].x)
202202
assert str(fig.data[0].x[0]) == str(fig.data[1].x[0])
203+
204+
205+
def test_overall_trendline():
206+
df = px.data.tips()
207+
fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols")
208+
assert len(fig1.data) == 2
209+
assert "trendline" in fig1.data[1].hovertemplate
210+
results1 = px.get_trendline_results(fig1)
211+
params1 = results1["px_fit_results"].iloc[0].params
212+
213+
fig2 = px.scatter(
214+
df,
215+
x="total_bill",
216+
y="tip",
217+
color="sex",
218+
trendline="ols",
219+
trendline_scope="overall",
220+
)
221+
assert len(fig2.data) == 3
222+
assert "trendline" in fig2.data[2].hovertemplate
223+
results2 = px.get_trendline_results(fig2)
224+
params2 = results2["px_fit_results"].iloc[0].params
225+
226+
assert np.all(np.array_equal(params1, params2))
227+
228+
fig3 = px.scatter(
229+
df,
230+
x="total_bill",
231+
y="tip",
232+
facet_row="sex",
233+
trendline="ols",
234+
trendline_scope="overall",
235+
)
236+
assert len(fig3.data) == 4
237+
assert "trendline" in fig3.data[3].hovertemplate
238+
results3 = px.get_trendline_results(fig3)
239+
params3 = results3["px_fit_results"].iloc[0].params
240+
241+
assert np.all(np.array_equal(params1, params3))

0 commit comments

Comments
 (0)