Skip to content

Commit 5830055

Browse files
Merge pull request #1875 from plotly/px_real_template
PX shouldn't modify attrs controlled by template
2 parents 06a2cb9 + 1e88daa commit 5830055

File tree

2 files changed

+147
-40
lines changed

2 files changed

+147
-40
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+40-40
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):
375375

376376
# Configure axis ticks on marginal subplots
377377
if args["marginal_x"]:
378-
fig.update_yaxes(
379-
showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
380-
)
381-
fig.update_xaxes(showgrid=True, row=nrows)
378+
fig.update_yaxes(showticklabels=False, row=nrows)
379+
if args["template"].layout.yaxis.showgrid is None:
380+
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
381+
if args["template"].layout.xaxis.showgrid is None:
382+
fig.update_xaxes(showgrid=True, row=nrows)
382383

383384
if args["marginal_y"]:
384-
fig.update_xaxes(
385-
showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
386-
)
387-
fig.update_yaxes(showgrid=True, col=ncols)
385+
fig.update_xaxes(showticklabels=False, col=ncols)
386+
if args["template"].layout.xaxis.showgrid is None:
387+
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
388+
if args["template"].layout.yaxis.showgrid is None:
389+
fig.update_yaxes(showgrid=True, col=ncols)
388390

389391
# Add axis titles to non-marginal subplots
390392
y_title = get_decorated_label(args, args["y"], "y")
@@ -687,55 +689,47 @@ def apply_default_cascade(args):
687689
else:
688690
args["template"] = "plotly"
689691

690-
# retrieve the actual template if we were given a name
691692
try:
692-
template = pio.templates[args["template"]]
693+
# retrieve the actual template if we were given a name
694+
args["template"] = pio.templates[args["template"]]
693695
except Exception:
694-
template = args["template"]
696+
# otherwise try to build a real template
697+
args["template"] = go.layout.Template(args["template"])
695698

696699
# if colors not set explicitly or in px.defaults, defer to a template
697700
# if the template doesn't have one, we set some final fallback defaults
698701
if "color_continuous_scale" in args:
699-
if args["color_continuous_scale"] is None:
700-
try:
701-
args["color_continuous_scale"] = [
702-
x[1] for x in template.layout.colorscale.sequential
703-
]
704-
except (AttributeError, TypeError):
705-
pass
702+
if (
703+
args["color_continuous_scale"] is None
704+
and args["template"].layout.colorscale.sequential
705+
):
706+
args["color_continuous_scale"] = [
707+
x[1] for x in args["template"].layout.colorscale.sequential
708+
]
706709
if args["color_continuous_scale"] is None:
707710
args["color_continuous_scale"] = sequential.Viridis
708711

709712
if "color_discrete_sequence" in args:
710-
if args["color_discrete_sequence"] is None:
711-
try:
712-
args["color_discrete_sequence"] = template.layout.colorway
713-
except (AttributeError, TypeError):
714-
pass
713+
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
714+
args["color_discrete_sequence"] = args["template"].layout.colorway
715715
if args["color_discrete_sequence"] is None:
716716
args["color_discrete_sequence"] = qualitative.D3
717717

718718
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
719719
# see if we can defer to template. If not, set reasonable defaults
720720
if "symbol_sequence" in args:
721-
if args["symbol_sequence"] is None:
722-
try:
723-
args["symbol_sequence"] = [
724-
scatter.marker.symbol for scatter in template.data.scatter
725-
]
726-
except (AttributeError, TypeError):
727-
pass
721+
if args["symbol_sequence"] is None and args["template"].data.scatter:
722+
args["symbol_sequence"] = [
723+
scatter.marker.symbol for scatter in args["template"].data.scatter
724+
]
728725
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
729726
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
730727

731728
if "line_dash_sequence" in args:
732-
if args["line_dash_sequence"] is None:
733-
try:
734-
args["line_dash_sequence"] = [
735-
scatter.line.dash for scatter in template.data.scatter
736-
]
737-
except (AttributeError, TypeError):
738-
pass
729+
if args["line_dash_sequence"] is None and args["template"].data.scatter:
730+
args["line_dash_sequence"] = [
731+
scatter.line.dash for scatter in args["template"].data.scatter
732+
]
739733
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
740734
args["line_dash_sequence"] = [
741735
"solid",
@@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12641258
cmax=range_color[1],
12651259
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
12661260
)
1267-
for v in ["title", "height", "width", "template"]:
1261+
for v in ["title", "height", "width"]:
12681262
if args[v]:
12691263
layout_patch[v] = args[v]
12701264
layout_patch["legend"] = {"tracegroupgap": 0}
1271-
if "title" not in layout_patch:
1265+
if "title" not in layout_patch and args["template"].layout.margin.t is None:
12721266
layout_patch["margin"] = {"t": 60}
1273-
if "size" in args and args["size"]:
1267+
if (
1268+
"size" in args
1269+
and args["size"]
1270+
and args["template"].layout.legend.itemsizing is None
1271+
):
12741272
layout_patch["legend"]["itemsizing"] = "constant"
12751273

12761274
fig = init_figure(
@@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12951293
# Add traces, layout and frames to figure
12961294
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
12971295
fig.layout.update(layout_patch)
1296+
if "template" in args and args["template"] is not None:
1297+
fig.update_layout(template=args["template"], overwrite=True)
12981298
fig.frames = frame_list if len(frames) > 1 else []
12991299

13001300
fig._px_trendlines = pd.DataFrame(trendline_rows)

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px.py

+107
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,110 @@ def test_custom_data_scatter():
5151
fig.data[0].hovertemplate
5252
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
5353
)
54+
55+
56+
def test_px_templates():
57+
import plotly.io as pio
58+
import plotly.graph_objects as go
59+
60+
tips = px.data.tips()
61+
62+
# use the normal defaults
63+
fig = px.scatter()
64+
assert fig.layout.template == pio.templates[pio.templates.default]
65+
66+
# respect changes to defaults
67+
pio.templates.default = "seaborn"
68+
fig = px.scatter()
69+
assert fig.layout.template == pio.templates["seaborn"]
70+
71+
# special px-level defaults over pio defaults
72+
pio.templates.default = "seaborn"
73+
px.defaults.template = "ggplot2"
74+
fig = px.scatter()
75+
assert fig.layout.template == pio.templates["ggplot2"]
76+
77+
# accept names in args over pio and px defaults
78+
fig = px.scatter(template="seaborn")
79+
assert fig.layout.template == pio.templates["seaborn"]
80+
81+
# accept objects in args
82+
fig = px.scatter(template={})
83+
assert fig.layout.template == go.layout.Template()
84+
85+
# read colorway from the template
86+
fig = px.scatter(
87+
tips,
88+
x="total_bill",
89+
y="tip",
90+
color="sex",
91+
template=dict(layout_colorway=["red", "blue"]),
92+
)
93+
assert fig.data[0].marker.color == "red"
94+
assert fig.data[1].marker.color == "blue"
95+
96+
# default colorway fallback
97+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict())
98+
assert fig.data[0].marker.color == px.colors.qualitative.D3[0]
99+
assert fig.data[1].marker.color == px.colors.qualitative.D3[1]
100+
101+
# pio default template colorway fallback
102+
pio.templates.default = "seaborn"
103+
px.defaults.template = None
104+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
105+
assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0]
106+
assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1]
107+
108+
# pio default template colorway fallback
109+
pio.templates.default = "seaborn"
110+
px.defaults.template = "ggplot2"
111+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
112+
assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0]
113+
assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1]
114+
115+
# don't overwrite top margin when set in template
116+
fig = px.scatter(title="yo")
117+
assert fig.layout.margin.t is None
118+
119+
fig = px.scatter()
120+
assert fig.layout.margin.t == 60
121+
122+
fig = px.scatter(template=dict(layout_margin_t=2))
123+
assert fig.layout.margin.t is None
124+
125+
# don't force histogram gridlines when set in template
126+
pio.templates.default = "none"
127+
px.defaults.template = None
128+
fig = px.scatter(
129+
tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram"
130+
)
131+
assert fig.layout.xaxis2.showgrid
132+
assert fig.layout.xaxis3.showgrid
133+
assert fig.layout.yaxis2.showgrid
134+
assert fig.layout.yaxis3.showgrid
135+
136+
fig = px.scatter(
137+
tips,
138+
x="total_bill",
139+
y="tip",
140+
marginal_x="histogram",
141+
marginal_y="histogram",
142+
template=dict(layout_yaxis_showgrid=False),
143+
)
144+
assert fig.layout.xaxis2.showgrid
145+
assert fig.layout.xaxis3.showgrid
146+
assert fig.layout.yaxis2.showgrid is None
147+
assert fig.layout.yaxis3.showgrid is None
148+
149+
fig = px.scatter(
150+
tips,
151+
x="total_bill",
152+
y="tip",
153+
marginal_x="histogram",
154+
marginal_y="histogram",
155+
template=dict(layout_xaxis_showgrid=False),
156+
)
157+
assert fig.layout.xaxis2.showgrid is None
158+
assert fig.layout.xaxis3.showgrid is None
159+
assert fig.layout.yaxis2.showgrid
160+
assert fig.layout.yaxis3.showgrid

0 commit comments

Comments
 (0)