Skip to content

Commit ae14231

Browse files
PX shouldn't modify attrs controlled by template
1 parent 06a2cb9 commit ae14231

File tree

1 file changed

+37
-39
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+37
-39
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+37-39
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):
375375

376376
# Configure axis ticks on marginal subplots
377377
if args["marginal_x"]:
378-
fig.update_yaxes(
379-
showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
380-
)
381-
fig.update_xaxes(showgrid=True, row=nrows)
378+
fig.update_yaxes(showticklabels=False, row=nrows)
379+
if args["template"].layout.yaxis.showgrid is None:
380+
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
381+
if args["template"].layout.xaxis.showgrid is None:
382+
fig.update_xaxes(showgrid=True, row=nrows)
382383

383384
if args["marginal_y"]:
384-
fig.update_xaxes(
385-
showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
386-
)
387-
fig.update_yaxes(showgrid=True, col=ncols)
385+
fig.update_xaxes(showticklabels=False, col=ncols)
386+
if args["template"].layout.xaxis.showgrid is None:
387+
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
388+
if args["template"].layout.yaxis.showgrid is None:
389+
fig.update_yaxes(showgrid=True, col=ncols)
388390

389391
# Add axis titles to non-marginal subplots
390392
y_title = get_decorated_label(args, args["y"], "y")
@@ -687,55 +689,47 @@ def apply_default_cascade(args):
687689
else:
688690
args["template"] = "plotly"
689691

690-
# retrieve the actual template if we were given a name
691692
try:
692-
template = pio.templates[args["template"]]
693+
# retrieve the actual template if we were given a name
694+
args["template"] = pio.templates[args["template"]]
693695
except Exception:
694-
template = args["template"]
696+
# otherwise try to build a real template
697+
args["template"] = go.layout.Template(args["template"])
695698

696699
# if colors not set explicitly or in px.defaults, defer to a template
697700
# if the template doesn't have one, we set some final fallback defaults
698701
if "color_continuous_scale" in args:
699-
if args["color_continuous_scale"] is None:
700-
try:
701-
args["color_continuous_scale"] = [
702-
x[1] for x in template.layout.colorscale.sequential
703-
]
704-
except (AttributeError, TypeError):
705-
pass
702+
if (
703+
args["color_continuous_scale"] is None
704+
and args["template"].layout.colorscale.sequential
705+
):
706+
args["color_continuous_scale"] = [
707+
x[1] for x in args["template"].layout.colorscale.sequential
708+
]
706709
if args["color_continuous_scale"] is None:
707710
args["color_continuous_scale"] = sequential.Viridis
708711

709712
if "color_discrete_sequence" in args:
710-
if args["color_discrete_sequence"] is None:
711-
try:
712-
args["color_discrete_sequence"] = template.layout.colorway
713-
except (AttributeError, TypeError):
714-
pass
713+
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
714+
args["color_discrete_sequence"] = args["template"].layout.colorway
715715
if args["color_discrete_sequence"] is None:
716716
args["color_discrete_sequence"] = qualitative.D3
717717

718718
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
719719
# see if we can defer to template. If not, set reasonable defaults
720720
if "symbol_sequence" in args:
721-
if args["symbol_sequence"] is None:
722-
try:
723-
args["symbol_sequence"] = [
724-
scatter.marker.symbol for scatter in template.data.scatter
725-
]
726-
except (AttributeError, TypeError):
727-
pass
721+
if args["symbol_sequence"] is None and args["template"].data.scatter:
722+
args["symbol_sequence"] = [
723+
scatter.marker.symbol for scatter in args["template"].data.scatter
724+
]
728725
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
729726
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
730727

731728
if "line_dash_sequence" in args:
732-
if args["line_dash_sequence"] is None:
733-
try:
734-
args["line_dash_sequence"] = [
735-
scatter.line.dash for scatter in template.data.scatter
736-
]
737-
except (AttributeError, TypeError):
738-
pass
729+
if args["line_dash_sequence"] is None and args["template"].data.scatter:
730+
args["line_dash_sequence"] = [
731+
scatter.line.dash for scatter in args["template"].data.scatter
732+
]
739733
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
740734
args["line_dash_sequence"] = [
741735
"solid",
@@ -1268,9 +1262,13 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12681262
if args[v]:
12691263
layout_patch[v] = args[v]
12701264
layout_patch["legend"] = {"tracegroupgap": 0}
1271-
if "title" not in layout_patch:
1265+
if "title" not in layout_patch and args["template"].layout.margin.t is None:
12721266
layout_patch["margin"] = {"t": 60}
1273-
if "size" in args and args["size"]:
1267+
if (
1268+
"size" in args
1269+
and args["size"]
1270+
and args["template"].layout.legend.itemsizing is None
1271+
):
12741272
layout_patch["legend"]["itemsizing"] = "constant"
12751273

12761274
fig = init_figure(

0 commit comments

Comments
 (0)